Mirror of https://github.com/langgenius/dify.git, synced 2026-01-02 20:47:20 +00:00

Compare commits: feat/node-...feat/updat (1 commit, 8d34db912f)
@@ -7,6 +7,5 @@ echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run
 echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
 echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
 echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
-source /home/vscode/.bashrc
 source /home/vscode/.bashrc
.github/workflows/api-tests.yml (vendored, 3 changes)
@@ -50,9 +50,6 @@ jobs:
       - name: Run ModelRuntime
         run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
 
-      - name: Run dify config tests
-        run: poetry run -C api python dev/pytest/pytest_config_tests.py
-
       - name: Run Tool
         run: poetry run -C api bash dev/pytest/pytest_tools.sh
.github/workflows/expose_service_ports.sh (vendored, 3 changes)
@@ -9,6 +9,5 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
 yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
 yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
 yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
-yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
 
-echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
+echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
.github/workflows/style.yml (vendored, 1 change)
@@ -37,7 +37,6 @@ jobs:
       - name: Ruff check
         if: steps.changed-files.outputs.any_changed == 'true'
         run: |
-          poetry run -C api ruff --version
          poetry run -C api ruff check ./api
          poetry run -C api ruff format --check ./api
.github/workflows/vdb-tests.yml (vendored, 3 changes)
@@ -51,7 +51,7 @@ jobs:
       - name: Expose Service Ports
         run: sh .github/workflows/expose_service_ports.sh
 
-      - name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
+      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
        uses: hoverkraft-tech/compose-action@v2.0.2
        with:
          compose-file: |

@@ -67,7 +67,6 @@ jobs:
           pgvector
           chroma
           elasticsearch
-          tidb
 
       - name: Test Vector Stores
         run: poetry run -C api bash dev/pytest/pytest_vdb.sh
@@ -56,27 +56,20 @@ DB_DATABASE=dify
 
 # Storage configuration
 # used to store uploaded files, private keys...
-# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
-STORAGE_TYPE=opendal
-
-# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
-OPENDAL_SCHEME=fs
-OPENDAL_FS_ROOT=storage
-
-# S3 Storage configuration
+# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
+STORAGE_TYPE=local
+STORAGE_LOCAL_PATH=storage
 S3_USE_AWS_MANAGED_IAM=false
 S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com
 S3_BUCKET_NAME=your-bucket-name
 S3_ACCESS_KEY=your-access-key
 S3_SECRET_KEY=your-secret-key
 S3_REGION=your-region
 
 # Azure Blob Storage configuration
 AZURE_BLOB_ACCOUNT_NAME=your-account-name
 AZURE_BLOB_ACCOUNT_KEY=your-account-key
 AZURE_BLOB_CONTAINER_NAME=your-container-name
 AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
 
 # Aliyun oss Storage configuration
 ALIYUN_OSS_BUCKET_NAME=your-bucket-name
 ALIYUN_OSS_ACCESS_KEY=your-access-key

@@ -86,7 +79,6 @@ ALIYUN_OSS_AUTH_VERSION=v1
 ALIYUN_OSS_REGION=your-region
-# Don't start with '/'. OSS doesn't support leading slash in object names.
 ALIYUN_OSS_PATH=your-path
 
 # Google Storage configuration
 GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
 GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string

@@ -133,8 +125,8 @@ SUPABASE_URL=your-server-url
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-# Vector database configuration
-# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
 VECTOR_STORE=weaviate
 
 # Weaviate configuration

@@ -285,7 +277,6 @@ VIKINGDB_SOCKET_TIMEOUT=30
 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
 LINDORM_USERNAME=admin
 LINDORM_PASSWORD=admin
-USING_UGC_INDEX=False
 
 # OceanBase Vector configuration
 OCEANBASE_VECTOR_HOST=127.0.0.1

@@ -304,7 +295,8 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 
 # Model configuration
-MULTIMODAL_SEND_FORMAT=base64
+MULTIMODAL_SEND_IMAGE_FORMAT=base64
+MULTIMODAL_SEND_VIDEO_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024

@@ -389,8 +381,6 @@ LOG_FILE_BACKUP_COUNT=5
 LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
-# Log Timezone
-LOG_TZ=UTC
 # Log format
 LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
 
 # Indexing configuration
 INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000

@@ -423,7 +413,3 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
 
 CREATE_TIDB_SERVICE_JOB_ENABLED=false
 
-# Maximum number of submitted thread count in a ThreadPool for parallel node execution
-MAX_SUBMIT_COUNT=100
-# Lockout duration in seconds
-LOGIN_LOCKOUT_DURATION=86400
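Note on the settings removed in the last hunk: MAX_SUBMIT_COUNT caps how many tasks may be handed to the thread pool that runs workflow nodes in parallel. A rough illustration of how such a cap is typically enforced (hypothetical code, not dify's actual executor):

# Illustrative only: enforcing a submit cap like MAX_SUBMIT_COUNT around a pool.
from concurrent.futures import ThreadPoolExecutor

MAX_SUBMIT_COUNT = 100

def run_parallel(tasks):
    with ThreadPoolExecutor() as pool:
        futures = []
        for task in tasks:
            if len(futures) >= MAX_SUBMIT_COUNT:
                # refuse further submissions once the cap is reached
                raise RuntimeError("Max submit count exceeded")
            futures.append(pool.submit(task))
        return [f.result() for f in futures]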
@@ -70,6 +70,7 @@ ignore = [
     "SIM113", # enumerate-for-loop
     "SIM117", # multiple-with-statements
     "SIM210", # if-expr-with-true-false
+    "SIM300", # yoda-conditions
 ]
 
 [lint.per-file-ignores]
api/app.py (27 changes)
@@ -1,30 +1,13 @@
-from libs import version_utils
+from app_factory import create_app
+from libs import threadings_utils, version_utils
 
 # preparation before creating app
 version_utils.check_supported_python_version()
-
-
-def is_db_command():
-    import sys
-
-    if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
-        return True
-    return False
-
+threadings_utils.apply_gevent_threading_patch()
 
 # create app
-if is_db_command():
-    from app_factory import create_migrations_app
-
-    app = create_migrations_app()
-else:
-    from app_factory import create_app
-    from libs import threadings_utils
-
-    threadings_utils.apply_gevent_threading_patch()
-
-    app = create_app()
-    celery = app.extensions["celery"]
+app = create_app()
+celery = app.extensions["celery"]
 
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=5001)
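For context on the threading patch the new top-level code always applies: libs.threadings_utils is not shown in this diff, but a gevent threading patch conventionally wraps gevent.monkey, roughly like this (a sketch under that assumption, not the actual dify helper):

from gevent import monkey

def apply_gevent_threading_patch() -> None:
    # Replace blocking stdlib primitives (threading, socket, time.sleep, ...)
    # with cooperative versions so they yield to gevent's event loop.
    monkey.patch_all()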
@@ -1,4 +1,5 @@
 import logging
+import os
 import time
 
 from configs import dify_config

@@ -16,6 +17,15 @@ def create_flask_app_with_configs() -> DifyApp:
     dify_app = DifyApp(__name__)
     dify_app.config.from_mapping(dify_config.model_dump())
 
+    # populate configs into system environment variables
+    for key, value in dify_app.config.items():
+        if isinstance(value, str):
+            os.environ[key] = value
+        elif isinstance(value, int | float | bool):
+            os.environ[key] = str(value)
+        elif value is None:
+            os.environ[key] = ""
+
     return dify_app

@@ -88,14 +98,3 @@ def initialize_extensions(app: DifyApp):
     end_time = time.perf_counter()
     if dify_config.DEBUG:
         logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
-
-
-def create_migrations_app():
-    app = create_flask_app_with_configs()
-    from extensions import ext_database, ext_migrate
-
-    # Initialize only required extensions
-    ext_database.init_app(app)
-    ext_migrate.init_app(app)
-
-    return app
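The type branches in the added loop exist because os.environ only accepts strings; a quick demonstration:

import os

os.environ["DEBUG"] = str(True)     # bools must be stringified -> "True"
os.environ["TIMEOUT"] = str(30.5)   # so must numbers -> "30.5"
os.environ["UNSET"] = ""            # None is mapped to the empty string
assert os.environ["DEBUG"] == "True"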
@@ -1,51 +1,11 @@
-import logging
-from typing import Any
+from pydantic_settings import SettingsConfigDict
 
-from pydantic.fields import FieldInfo
-from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
-
-from .deploy import DeploymentConfig
-from .enterprise import EnterpriseFeatureConfig
-from .extra import ExtraServiceConfig
-from .feature import FeatureConfig
-from .middleware import MiddlewareConfig
-from .packaging import PackagingInfo
-from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
-from .remote_settings_sources.apollo import ApolloSettingsSource
-
-logger = logging.getLogger(__name__)
-
-
-class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
-    def __init__(self, settings_cls: type[BaseSettings]):
-        super().__init__(settings_cls)
-
-    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
-        raise NotImplementedError
-
-    def __call__(self) -> dict[str, Any]:
-        current_state = self.current_state
-        remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
-        if not remote_source_name:
-            return {}
-
-        remote_source: RemoteSettingsSource | None = None
-        match remote_source_name:
-            case RemoteSettingsSourceName.APOLLO:
-                remote_source = ApolloSettingsSource(current_state)
-            case _:
-                logger.warning(f"Unsupported remote source: {remote_source_name}")
-                return {}
-
-        d: dict[str, Any] = {}
-
-        for field_name, field in self.settings_cls.model_fields.items():
-            field_value, field_key, value_is_complex = remote_source.get_field_value(field, field_name)
-            field_value = remote_source.prepare_field_value(field_name, field, field_value, value_is_complex)
-            if field_value is not None:
-                d[field_key] = field_value
-
-        return d
+from configs.deploy import DeploymentConfig
+from configs.enterprise import EnterpriseFeatureConfig
+from configs.extra import ExtraServiceConfig
+from configs.feature import FeatureConfig
+from configs.middleware import MiddlewareConfig
+from configs.packaging import PackagingInfo
 
 
 class DifyConfig(

@@ -59,8 +19,6 @@ class DifyConfig(
     MiddlewareConfig,
     # Extra service configs
     ExtraServiceConfig,
-    # Remote source configs
-    RemoteSettingsSourceConfig,
     # Enterprise feature configs
     # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
     EnterpriseFeatureConfig,

@@ -77,20 +35,3 @@ class DifyConfig(
     # please consider to arrange it in the proper config group of existed or added
     # for better readability and maintainability.
     # Thanks for your concentration and consideration.
-
-    @classmethod
-    def settings_customise_sources(
-        cls,
-        settings_cls: type[BaseSettings],
-        init_settings: PydanticBaseSettingsSource,
-        env_settings: PydanticBaseSettingsSource,
-        dotenv_settings: PydanticBaseSettingsSource,
-        file_secret_settings: PydanticBaseSettingsSource,
-    ) -> tuple[PydanticBaseSettingsSource, ...]:
-        return (
-            init_settings,
-            env_settings,
-            RemoteSettingsSourceFactory(settings_cls),
-            dotenv_settings,
-            file_secret_settings,
-        )
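The removed settings_customise_sources hook controls pydantic-settings precedence: earlier sources in the returned tuple win. A minimal standalone sketch of the same mechanism (an illustrative class, not DifyConfig itself):

from pydantic_settings import BaseSettings, PydanticBaseSettingsSource

class ExampleSettings(BaseSettings):
    APP_NAME: str = "default"

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls,
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ):
        # init kwargs > env vars > (a remote source slotted here) > .env > secret files
        return (init_settings, env_settings, dotenv_settings, file_secret_settings)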
@@ -439,17 +439,6 @@ class WorkflowConfig(BaseSettings):
     )
 
 
-class WorkflowNodeExecutionConfig(BaseSettings):
-    """
-    Configuration for workflow node execution
-    """
-
-    MAX_SUBMIT_COUNT: PositiveInt = Field(
-        description="Maximum number of submitted thread count in a ThreadPool for parallel node execution",
-        default=100,
-    )
-
-
 class AuthConfig(BaseSettings):
     """
     Configuration for authentication and OAuth

@@ -485,11 +474,6 @@ class AuthConfig(BaseSettings):
         default=60,
     )
 
-    LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
-        description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
-        default=86400,
-    )
-
 
 class ModerationConfig(BaseSettings):
     """

@@ -665,9 +649,14 @@ class IndexingConfig(BaseSettings):
     )
 
 
-class MultiModalTransferConfig(BaseSettings):
-    MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
-        description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
+class VisionFormatConfig(BaseSettings):
+    MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
+        description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
         default="base64",
     )
 
+    MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
+        description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
+        default="base64",
+    )
+

@@ -773,20 +762,19 @@ class FeatureConfig(
     FileAccessConfig,
     FileUploadConfig,
     HttpConfig,
+    VisionFormatConfig,
     InnerAPIConfig,
     IndexingConfig,
     LoggingConfig,
     MailConfig,
     ModelLoadBalanceConfig,
     ModerationConfig,
-    MultiModalTransferConfig,
     PositionConfig,
     RagEtlConfig,
     SecurityConfig,
     ToolConfig,
     UpdateConfig,
     WorkflowConfig,
-    WorkflowNodeExecutionConfig,
     WorkspaceConfig,
     LoginConfig,
     # hosted services config
@@ -1,69 +1,54 @@
-from typing import Any, Literal, Optional
+from typing import Any, Optional
 from urllib.parse import quote_plus
 
 from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
 from pydantic_settings import BaseSettings
 
-from .cache.redis_config import RedisConfig
-from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
-from .storage.amazon_s3_storage_config import S3StorageConfig
-from .storage.azure_blob_storage_config import AzureBlobStorageConfig
-from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
-from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
-from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
-from .storage.oci_storage_config import OCIStorageConfig
-from .storage.opendal_storage_config import OpenDALStorageConfig
-from .storage.supabase_storage_config import SupabaseStorageConfig
-from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
-from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
-from .vdb.analyticdb_config import AnalyticdbConfig
-from .vdb.baidu_vector_config import BaiduVectorDBConfig
-from .vdb.chroma_config import ChromaConfig
-from .vdb.couchbase_config import CouchbaseConfig
-from .vdb.elasticsearch_config import ElasticsearchConfig
-from .vdb.lindorm_config import LindormConfig
-from .vdb.milvus_config import MilvusConfig
-from .vdb.myscale_config import MyScaleConfig
-from .vdb.oceanbase_config import OceanBaseVectorConfig
-from .vdb.opensearch_config import OpenSearchConfig
-from .vdb.oracle_config import OracleConfig
-from .vdb.pgvector_config import PGVectorConfig
-from .vdb.pgvectors_config import PGVectoRSConfig
-from .vdb.qdrant_config import QdrantConfig
-from .vdb.relyt_config import RelytConfig
-from .vdb.tencent_vector_config import TencentVectorDBConfig
-from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
-from .vdb.tidb_vector_config import TiDBVectorConfig
-from .vdb.upstash_config import UpstashConfig
-from .vdb.vikingdb_config import VikingDBConfig
-from .vdb.weaviate_config import WeaviateConfig
+from configs.middleware.cache.redis_config import RedisConfig
+from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
+from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
+from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
+from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageConfig
+from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
+from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
+from configs.middleware.storage.oci_storage_config import OCIStorageConfig
+from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
+from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
+from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
+from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
+from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
+from configs.middleware.vdb.chroma_config import ChromaConfig
+from configs.middleware.vdb.couchbase_config import CouchbaseConfig
+from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
+from configs.middleware.vdb.lindorm_config import LindormConfig
+from configs.middleware.vdb.milvus_config import MilvusConfig
+from configs.middleware.vdb.myscale_config import MyScaleConfig
+from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
+from configs.middleware.vdb.opensearch_config import OpenSearchConfig
+from configs.middleware.vdb.oracle_config import OracleConfig
+from configs.middleware.vdb.pgvector_config import PGVectorConfig
+from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
+from configs.middleware.vdb.qdrant_config import QdrantConfig
+from configs.middleware.vdb.relyt_config import RelytConfig
+from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
+from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
+from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
+from configs.middleware.vdb.upstash_config import UpstashConfig
+from configs.middleware.vdb.vikingdb_config import VikingDBConfig
+from configs.middleware.vdb.weaviate_config import WeaviateConfig
 
 
 class StorageConfig(BaseSettings):
-    STORAGE_TYPE: Literal[
-        "opendal",
-        "s3",
-        "aliyun-oss",
-        "azure-blob",
-        "baidu-obs",
-        "google-storage",
-        "huawei-obs",
-        "oci-storage",
-        "tencent-cos",
-        "volcengine-tos",
-        "supabase",
-        "local",
-    ] = Field(
+    STORAGE_TYPE: str = Field(
         description="Type of storage to use."
-        " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
-        "'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
-        default="opendal",
+        " Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', "
+        "'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.",
+        default="local",
     )
 
     STORAGE_LOCAL_PATH: str = Field(
         description="Path for local storage when STORAGE_TYPE is set to 'local'.",
         default="storage",
-        deprecated=True,
     )

@@ -88,7 +73,7 @@ class KeywordStoreConfig(BaseSettings):
     )
 
 
-class DatabaseConfig(BaseSettings):
+class DatabaseConfig:
     DB_HOST: str = Field(
         description="Hostname or IP address of the database server.",
         default="localhost",

@@ -250,7 +235,6 @@ class MiddlewareConfig(
     GoogleCloudStorageConfig,
     HuaweiCloudOBSStorageConfig,
     OCIStorageConfig,
-    OpenDALStorageConfig,
     S3StorageConfig,
     SupabaseStorageConfig,
     TencentCloudCOSStorageConfig,
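The STORAGE_TYPE change above trades validation strictness for flexibility: a Literal field rejects unknown values at config-load time, while a plain str defers any mismatch to runtime. An illustration:

from typing import Literal
from pydantic import BaseModel, ValidationError

class Strict(BaseModel):
    STORAGE_TYPE: Literal["local", "s3"] = "local"

class Loose(BaseModel):
    STORAGE_TYPE: str = "local"

try:
    Strict(STORAGE_TYPE="s3-compat")
except ValidationError:
    pass  # Literal catches the unsupported value immediately
Loose(STORAGE_TYPE="s3-compat")  # accepted; any typo surfaces only at runtime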
@@ -1,10 +1,9 @@
 from typing import Optional
 
-from pydantic import Field
-from pydantic_settings import BaseSettings
+from pydantic import BaseModel, Field
 
 
-class BaiduOBSStorageConfig(BaseSettings):
+class BaiduOBSStorageConfig(BaseModel):
     """
     Configuration settings for Baidu Object Storage Service (OBS)
     """
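This BaseSettings-to-BaseModel swap (repeated for several config classes below) changes where environment variables are read: BaseSettings pulls them itself, while BaseModel does not, leaving env resolution to the top-level settings class these models are composed into. A demonstration:

import os
from pydantic import BaseModel
from pydantic_settings import BaseSettings

os.environ["BUCKET"] = "from-env"

class AsSettings(BaseSettings):
    BUCKET: str = "default"

class AsModel(BaseModel):
    BUCKET: str = "default"

print(AsSettings().BUCKET)  # "from-env" -- BaseSettings reads the environment
print(AsModel().BUCKET)     # "default"  -- BaseModel ignores it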
@@ -1,10 +1,9 @@
 from typing import Optional
 
-from pydantic import Field
-from pydantic_settings import BaseSettings
+from pydantic import BaseModel, Field
 
 
-class HuaweiCloudOBSStorageConfig(BaseSettings):
+class HuaweiCloudOBSStorageConfig(BaseModel):
     """
     Configuration settings for Huawei Cloud Object Storage Service (OBS)
     """
@@ -1,9 +0,0 @@
-from pydantic import Field
-from pydantic_settings import BaseSettings
-
-
-class OpenDALStorageConfig(BaseSettings):
-    OPENDAL_SCHEME: str = Field(
-        default="fs",
-        description="OpenDAL scheme.",
-    )
@@ -1,10 +1,9 @@
 from typing import Optional
 
-from pydantic import Field
-from pydantic_settings import BaseSettings
+from pydantic import BaseModel, Field
 
 
-class SupabaseStorageConfig(BaseSettings):
+class SupabaseStorageConfig(BaseModel):
     """
     Configuration settings for Supabase Object Storage Service
     """
@@ -1,10 +1,9 @@
 from typing import Optional
 
-from pydantic import Field
-from pydantic_settings import BaseSettings
+from pydantic import BaseModel, Field
 
 
-class VolcengineTOSStorageConfig(BaseSettings):
+class VolcengineTOSStorageConfig(BaseModel):
     """
     Configuration settings for Volcengine Tinder Object Storage (TOS)
     """
@@ -1,10 +1,9 @@
 from typing import Optional
 
-from pydantic import Field, PositiveInt
-from pydantic_settings import BaseSettings
+from pydantic import BaseModel, Field, PositiveInt
 
 
-class AnalyticdbConfig(BaseSettings):
+class AnalyticdbConfig(BaseModel):
     """
     Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
     Refer to the following documentation for details on obtaining credentials:
@@ -1,10 +1,9 @@
 from typing import Optional
 
-from pydantic import Field
-from pydantic_settings import BaseSettings
+from pydantic import BaseModel, Field
 
 
-class CouchbaseConfig(BaseSettings):
+class CouchbaseConfig(BaseModel):
     """
     Couchbase configs
     """
@@ -21,14 +21,3 @@ class LindormConfig(BaseSettings):
         description="Lindorm password",
         default=None,
     )
-    DEFAULT_INDEX_TYPE: Optional[str] = Field(
-        description="Lindorm Vector Index Type, hnsw or flat is available in dify",
-        default="hnsw",
-    )
-    DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
-        description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
-    )
-    USING_UGC_INDEX: Optional[bool] = Field(
-        description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
-        default=False,
-    )
@@ -1,8 +1,7 @@
-from pydantic import Field, PositiveInt
-from pydantic_settings import BaseSettings
+from pydantic import BaseModel, Field, PositiveInt
 
 
-class MyScaleConfig(BaseSettings):
+class MyScaleConfig(BaseModel):
     """
     Configuration settings for MyScale vector database
     """
@@ -1,10 +1,9 @@
 from typing import Optional
 
-from pydantic import Field
-from pydantic_settings import BaseSettings
+from pydantic import BaseModel, Field
 
 
-class VikingDBConfig(BaseSettings):
+class VikingDBConfig(BaseModel):
     """
     Configuration for connecting to Volcengine VikingDB.
     Refer to the following documentation for details on obtaining credentials:
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.14.1",
+        default="0.13.1",
     )
 
     COMMIT_SHA: str = Field(
@@ -1,17 +0,0 @@
-from typing import Optional
-
-from pydantic import Field
-
-from .apollo import ApolloSettingsSourceInfo
-from .base import RemoteSettingsSource
-from .enums import RemoteSettingsSourceName
-
-
-class RemoteSettingsSourceConfig(ApolloSettingsSourceInfo):
-    REMOTE_SETTINGS_SOURCE_NAME: RemoteSettingsSourceName | str = Field(
-        description="name of remote config source",
-        default="",
-    )
-
-
-__all__ = ["RemoteSettingsSource", "RemoteSettingsSourceConfig", "RemoteSettingsSourceName"]
@@ -1,55 +0,0 @@
-from collections.abc import Mapping
-from typing import Any, Optional
-
-from pydantic import Field
-from pydantic.fields import FieldInfo
-from pydantic_settings import BaseSettings
-
-from configs.remote_settings_sources.base import RemoteSettingsSource
-
-from .client import ApolloClient
-
-
-class ApolloSettingsSourceInfo(BaseSettings):
-    """
-    Packaging build information
-    """
-
-    APOLLO_APP_ID: Optional[str] = Field(
-        description="apollo app_id",
-        default=None,
-    )
-
-    APOLLO_CLUSTER: Optional[str] = Field(
-        description="apollo cluster",
-        default=None,
-    )
-
-    APOLLO_CONFIG_URL: Optional[str] = Field(
-        description="apollo config url",
-        default=None,
-    )
-
-    APOLLO_NAMESPACE: Optional[str] = Field(
-        description="apollo namespace",
-        default=None,
-    )
-
-
-class ApolloSettingsSource(RemoteSettingsSource):
-    def __init__(self, configs: Mapping[str, Any]):
-        self.client = ApolloClient(
-            app_id=configs["APOLLO_APP_ID"],
-            cluster=configs["APOLLO_CLUSTER"],
-            config_url=configs["APOLLO_CONFIG_URL"],
-            start_hot_update=False,
-            _notification_map={configs["APOLLO_NAMESPACE"]: -1},
-        )
-        self.namespace = configs["APOLLO_NAMESPACE"]
-        self.remote_configs = self.client.get_all_dicts(self.namespace)
-
-    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
-        if not isinstance(self.remote_configs, dict):
-            raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
-        field_value = self.remote_configs.get(field_name)
-        return field_value, field_name, False
@@ -1,303 +0,0 @@
-import hashlib
-import json
-import logging
-import os
-import threading
-import time
-from pathlib import Path
-
-from .python_3x import http_request, makedirs_wrapper
-from .utils import (
-    CONFIGURATIONS,
-    NAMESPACE_NAME,
-    NOTIFICATION_ID,
-    get_value_from_dict,
-    init_ip,
-    no_key_cache_key,
-    signature,
-    url_encode_wrapper,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class ApolloClient:
-    def __init__(
-        self,
-        config_url,
-        app_id,
-        cluster="default",
-        secret="",
-        start_hot_update=True,
-        change_listener=None,
-        _notification_map=None,
-    ):
-        # Core routing parameters
-        self.config_url = config_url
-        self.cluster = cluster
-        self.app_id = app_id
-
-        # Non-core parameters
-        self.ip = init_ip()
-        self.secret = secret
-
-        # Check the parameter variables
-
-        # Private control variables
-        self._cycle_time = 5
-        self._stopping = False
-        self._cache = {}
-        self._no_key = {}
-        self._hash = {}
-        self._pull_timeout = 75
-        self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
-        self._long_poll_thread = None
-        self._change_listener = change_listener  # "add" "delete" "update"
-        if _notification_map is None:
-            _notification_map = {"application": -1}
-        self._notification_map = _notification_map
-        self.last_release_key = None
-        # Private startup method
-        self._path_checker()
-        if start_hot_update:
-            self._start_hot_update()
-
-        # start the heartbeat thread
-        heartbeat = threading.Thread(target=self._heart_beat)
-        heartbeat.daemon = True
-        heartbeat.start()
-
-    def get_json_from_net(self, namespace="application"):
-        url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
-            self.config_url, self.app_id, self.cluster, namespace, "", self.ip
-        )
-        try:
-            code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
-            if code == 200:
-                if not body:
-                    logger.error(f"get_json_from_net load configs failed, body is {body}")
-                    return None
-                data = json.loads(body)
-                data = data["configurations"]
-                return_data = {CONFIGURATIONS: data}
-                return return_data
-            else:
-                return None
-        except Exception:
-            logger.exception("an error occurred in get_json_from_net")
-            return None
-
-    def get_value(self, key, default_val=None, namespace="application"):
-        try:
-            # read memory configuration
-            namespace_cache = self._cache.get(namespace)
-            val = get_value_from_dict(namespace_cache, key)
-            if val is not None:
-                return val
-
-            no_key = no_key_cache_key(namespace, key)
-            if no_key in self._no_key:
-                return default_val
-
-            # read the network configuration
-            namespace_data = self.get_json_from_net(namespace)
-            val = get_value_from_dict(namespace_data, key)
-            if val is not None:
-                self._update_cache_and_file(namespace_data, namespace)
-                return val
-
-            # read the file configuration
-            namespace_cache = self._get_local_cache(namespace)
-            val = get_value_from_dict(namespace_cache, key)
-            if val is not None:
-                self._update_cache_and_file(namespace_cache, namespace)
-                return val
-
-            # If all of them are not obtained, the default value is returned
-            # and the local cache is set to None
-            self._set_local_cache_none(namespace, key)
-            return default_val
-        except Exception:
-            logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace)
-            return default_val
-
-    # Set the key of a namespace to none, and do not set default val
-    # to ensure the real-time correctness of the function call.
-    # If the user does not have the same default val twice
-    # and the default val is used here, there may be a problem.
-    def _set_local_cache_none(self, namespace, key):
-        no_key = no_key_cache_key(namespace, key)
-        self._no_key[no_key] = key
-
-    def _start_hot_update(self):
-        self._long_poll_thread = threading.Thread(target=self._listener)
-        # When the asynchronous thread is started, the daemon thread will automatically exit
-        # when the main thread is launched.
-        self._long_poll_thread.daemon = True
-        self._long_poll_thread.start()
-
-    def stop(self):
-        self._stopping = True
-        logger.info("Stopping listener...")
-
-    # Call the set callback function, and if it is abnormal, try it out
-    def _call_listener(self, namespace, old_kv, new_kv):
-        if self._change_listener is None:
-            return
-        if old_kv is None:
-            old_kv = {}
-        if new_kv is None:
-            new_kv = {}
-        try:
-            for key in old_kv:
-                new_value = new_kv.get(key)
-                old_value = old_kv.get(key)
-                if new_value is None:
-                    # If newValue is empty, it means key, and the value is deleted.
-                    self._change_listener("delete", namespace, key, old_value)
-                    continue
-                if new_value != old_value:
-                    self._change_listener("update", namespace, key, new_value)
-                    continue
-            for key in new_kv:
-                new_value = new_kv.get(key)
-                old_value = old_kv.get(key)
-                if old_value is None:
-                    self._change_listener("add", namespace, key, new_value)
-        except BaseException as e:
-            logger.warning(str(e))
-
-    def _path_checker(self):
-        if not os.path.isdir(self._cache_file_path):
-            makedirs_wrapper(self._cache_file_path)
-
-    # update the local cache and file cache
-    def _update_cache_and_file(self, namespace_data, namespace="application"):
-        # update the local cache
-        self._cache[namespace] = namespace_data
-        # update the file cache
-        new_string = json.dumps(namespace_data)
-        new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest()
-        if self._hash.get(namespace) == new_hash:
-            pass
-        else:
-            file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt"
-            file_path.write_text(new_string)
-            self._hash[namespace] = new_hash
-
-    # get the configuration from the local file
-    def _get_local_cache(self, namespace="application"):
-        cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
-        if os.path.isfile(cache_file_path):
-            with open(cache_file_path) as f:
-                result = json.loads(f.readline())
-            return result
-        return {}
-
-    def _long_poll(self):
-        notifications = []
-        for key in self._cache:
-            namespace_data = self._cache[key]
-            notification_id = -1
-            if NOTIFICATION_ID in namespace_data:
-                notification_id = self._cache[key][NOTIFICATION_ID]
-            notifications.append({NAMESPACE_NAME: key, NOTIFICATION_ID: notification_id})
-        try:
-            # if the length is 0 it is returned directly
-            if len(notifications) == 0:
-                return
-            url = "{}/notifications/v2".format(self.config_url)
-            params = {
-                "appId": self.app_id,
-                "cluster": self.cluster,
-                "notifications": json.dumps(notifications, ensure_ascii=False),
-            }
-            param_str = url_encode_wrapper(params)
-            url = url + "?" + param_str
-            code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url))
-            http_code = code
-            if http_code == 304:
-                logger.debug("No change, loop...")
-                return
-            if http_code == 200:
-                if not body:
-                    logger.error(f"_long_poll load configs failed, body is {body}")
-                    return
-                data = json.loads(body)
-                for entry in data:
-                    namespace = entry[NAMESPACE_NAME]
-                    n_id = entry[NOTIFICATION_ID]
-                    logger.info("%s has changes: notificationId=%d", namespace, n_id)
-                    self._get_net_and_set_local(namespace, n_id, call_change=True)
-                return
-            else:
-                logger.warning("Sleep...")
-        except Exception as e:
-            logger.warning(str(e))
-
-    def _get_net_and_set_local(self, namespace, n_id, call_change=False):
-        namespace_data = self.get_json_from_net(namespace)
-        if not namespace_data:
-            return
-        namespace_data[NOTIFICATION_ID] = n_id
-        old_namespace = self._cache.get(namespace)
-        self._update_cache_and_file(namespace_data, namespace)
-        if self._change_listener is not None and call_change and old_namespace:
-            old_kv = old_namespace.get(CONFIGURATIONS)
-            new_kv = namespace_data.get(CONFIGURATIONS)
-            self._call_listener(namespace, old_kv, new_kv)
-
-    def _listener(self):
-        logger.info("start long_poll")
-        while not self._stopping:
-            self._long_poll()
-            time.sleep(self._cycle_time)
-        logger.info("stopped, long_poll")
-
-    # add the need for endorsement to the header
-    def _sign_headers(self, url):
-        headers = {}
-        if self.secret == "":
-            return headers
-        uri = url[len(self.config_url) : len(url)]
-        time_unix_now = str(int(round(time.time() * 1000)))
-        headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret)
-        headers["Timestamp"] = time_unix_now
-        return headers
-
-    def _heart_beat(self):
-        while not self._stopping:
-            for namespace in self._notification_map:
-                self._do_heart_beat(namespace)
-            time.sleep(60 * 10)  # 10 minutes
-
-    def _do_heart_beat(self, namespace):
-        url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
-        try:
-            code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
-            if code == 200:
-                if not body:
-                    logger.error(f"_do_heart_beat load configs failed, body is {body}")
-                    return None
-                data = json.loads(body)
-                if self.last_release_key == data["releaseKey"]:
-                    return None
-                self.last_release_key = data["releaseKey"]
-                data = data["configurations"]
-                self._update_cache_and_file(data, namespace)
-            else:
-                return None
-        except Exception:
-            logger.exception("an error occurred in _do_heart_beat")
-            return None
-
-    def get_all_dicts(self, namespace):
-        namespace_data = self._cache.get(namespace)
-        if namespace_data is None:
-            net_namespace_data = self.get_json_from_net(namespace)
-            if not net_namespace_data:
-                return namespace_data
-            namespace_data = net_namespace_data.get(CONFIGURATIONS)
-            if namespace_data:
-                self._update_cache_and_file(namespace_data, namespace)
-        return namespace_data
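A usage sketch of the client being deleted, based only on the code above (the URL and app id are placeholders):

client = ApolloClient(
    config_url="http://apollo.example.com:8080",
    app_id="dify",
    cluster="default",
    start_hot_update=False,              # skip the long-poll listener thread
    _notification_map={"application": -1},
)
value = client.get_value("SOME_KEY", default_val="fallback", namespace="application")
configs = client.get_all_dicts("application")  # memory -> network -> file fallback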
@@ -1,41 +0,0 @@
-import logging
-import os
-import ssl
-import urllib.request
-from urllib import parse
-from urllib.error import HTTPError
-
-# Create an SSL context that allows for a lower level of security
-ssl_context = ssl.create_default_context()
-ssl_context.set_ciphers("HIGH:!DH:!aNULL")
-ssl_context.check_hostname = False
-ssl_context.verify_mode = ssl.CERT_NONE
-
-# Create an opener object and pass in a custom SSL context
-opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context))
-
-urllib.request.install_opener(opener)
-
-logger = logging.getLogger(__name__)
-
-
-def http_request(url, timeout, headers={}):
-    try:
-        request = urllib.request.Request(url, headers=headers)
-        res = urllib.request.urlopen(request, timeout=timeout)
-        body = res.read().decode("utf-8")
-        return res.code, body
-    except HTTPError as e:
-        if e.code == 304:
-            logger.warning("http_request error, code is 304, maybe you should check secret")
-            return 304, None
-        logger.warning("http_request error, code is %d, msg is %s", e.code, e.msg)
-        raise e
-
-
-def url_encode(params):
-    return parse.urlencode(params)
-
-
-def makedirs_wrapper(path):
-    os.makedirs(path, exist_ok=True)
@@ -1,51 +0,0 @@
-import hashlib
-import socket
-
-from .python_3x import url_encode
-
-# define constants
-CONFIGURATIONS = "configurations"
-NOTIFICATION_ID = "notificationId"
-NAMESPACE_NAME = "namespaceName"
-
-
-# add timestamps uris and keys
-def signature(timestamp, uri, secret):
-    import base64
-    import hmac
-
-    string_to_sign = "" + timestamp + "\n" + uri
-    hmac_code = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha1).digest()
-    return base64.b64encode(hmac_code).decode()
-
-
-def url_encode_wrapper(params):
-    return url_encode(params)
-
-
-def no_key_cache_key(namespace, key):
-    return "{}{}{}".format(namespace, len(namespace), key)
-
-
-# Returns the value for the key if present, or None if it is not
-def get_value_from_dict(namespace_cache, key):
-    if namespace_cache:
-        kv_data = namespace_cache.get(CONFIGURATIONS)
-        if kv_data is None:
-            return None
-        if key in kv_data:
-            return kv_data[key]
-    return None
-
-
-def init_ip():
-    ip = ""
-    s = None
-    try:
-        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        s.connect(("8.8.8.8", 53))
-        ip = s.getsockname()[0]
-    finally:
-        if s:
-            s.close()
-    return ip
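A worked example of the signature() scheme above: HMAC-SHA1 over "<timestamp>\n<uri>", base64-encoded (the values are illustrative):

import base64, hashlib, hmac

timestamp = "1700000000000"                     # milliseconds since epoch, as a string
uri = "/configs/dify/default/application?ip=10.0.0.1"
secret = "example-secret"

string_to_sign = timestamp + "\n" + uri
digest = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha1).digest()
print(base64.b64encode(digest).decode())        # goes into the Authorization header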
@@ -1,15 +0,0 @@
-from collections.abc import Mapping
-from typing import Any
-
-from pydantic.fields import FieldInfo
-
-
-class RemoteSettingsSource:
-    def __init__(self, configs: Mapping[str, Any]):
-        pass
-
-    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
-        raise NotImplementedError
-
-    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
-        return value
@@ -1,5 +0,0 @@
-from enum import StrEnum
-
-
-class RemoteSettingsSourceName(StrEnum):
-    APOLLO = "apollo"
@@ -14,11 +14,11 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
 
 
 if dify_config.ETL_TYPE == "Unstructured":
-    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
+    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
     DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
     if dify_config.UNSTRUCTURED_API_URL:
         DOCUMENT_EXTENSIONS.append("ppt")
     DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
 else:
-    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
+    DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
     DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
@@ -31,7 +31,7 @@ def admin_required(view):
         if auth_scheme != "bearer":
             raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
 
-        if auth_token != dify_config.ADMIN_API_KEY:
+        if dify_config.ADMIN_API_KEY != auth_token:
             raise Unauthorized("API key is invalid.")
 
         return view(*args, **kwargs)
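The swapped operand order is the "Yoda" comparison style that Ruff's SIM300 rule flags, which lines up with SIM300 being re-added to the ignore list in the pyproject hunk earlier. The two forms are equivalent:

auth_token, admin_api_key = "token-a", "token-b"  # illustrative values
assert (auth_token != admin_api_key) == (admin_api_key != auth_token)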
@@ -65,7 +65,7 @@ class ModelConfigResource(Resource):
                 provider_type=agent_tool_entity.provider_type,
                 identity_id=f"AGENT.{app_model.id}",
             )
-        except Exception:
+        except Exception as e:
             continue
 
         # get decrypted parameters

@@ -97,7 +97,7 @@ class ModelConfigResource(Resource):
                 app_id=app_model.id,
                 agent_tool=agent_tool_entity,
             )
-        except Exception:
+        except Exception as e:
             continue
 
         manager = ToolParameterConfigurationManager(
@@ -1,5 +1,4 @@
 from flask_restful import Resource, reqparse
-from werkzeug.exceptions import BadRequest
 
 from controllers.console import api
 from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist

@@ -27,7 +26,7 @@ class TraceAppConfigApi(Resource):
                 return {"has_not_configured": True}
             return trace_config
         except Exception as e:
-            raise BadRequest(str(e))
+            raise e
 
     @setup_required
     @login_required

@@ -49,7 +48,7 @@ class TraceAppConfigApi(Resource):
                 raise TracingConfigCheckError()
             return result
         except Exception as e:
-            raise BadRequest(str(e))
+            raise e
 
     @setup_required
     @login_required

@@ -69,7 +68,7 @@ class TraceAppConfigApi(Resource):
                 raise TracingConfigNotExist()
             return {"result": "success"}
         except Exception as e:
-            raise BadRequest(str(e))
+            raise e
 
     @setup_required
     @login_required

@@ -86,7 +85,7 @@ class TraceAppConfigApi(Resource):
                 raise TracingConfigNotExist()
             return {"result": "success"}
         except Exception as e:
-            raise BadRequest(str(e))
+            raise e
 
 
 api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")
@@ -948,7 +948,7 @@ class DocumentRetryApi(DocumentResource):
                 if document.indexing_status == "completed":
                     raise DocumentAlreadyFinishedError()
                 retry_documents.append(document)
-            except Exception:
+            except Exception as e:
                 logging.exception(f"Failed to retry document, document id: {document_id}")
                 continue
         # retry document
@@ -1,6 +1,5 @@
 from datetime import UTC, datetime
 
-from flask import request
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal_with, reqparse
 from sqlalchemy import and_

@@ -21,17 +20,8 @@ class InstalledAppsListApi(Resource):
     @account_initialization_required
     @marshal_with(installed_app_list_fields)
     def get(self):
-        app_id = request.args.get("app_id", default=None, type=str)
         current_tenant_id = current_user.current_tenant_id
-
-        if app_id:
-            installed_apps = (
-                db.session.query(InstalledApp)
-                .filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
-                .all()
-            )
-        else:
-            installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
+        installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
 
         current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
         installed_apps = [
@@ -4,7 +4,6 @@ from flask_restful import Resource, fields, marshal_with, reqparse
 from constants.languages import languages
 from controllers.console import api
 from controllers.console.wraps import account_initialization_required
-from libs.helper import AppIconUrlField
 from libs.login import login_required
 from services.recommended_app_service import RecommendedAppService
 

@@ -13,8 +12,6 @@ app_fields = {
     "name": fields.String,
     "mode": fields.String,
     "icon": fields.String,
-    "icon_type": fields.String,
-    "icon_url": AppIconUrlField,
     "icon_background": fields.String,
 }
@@ -1,7 +1,6 @@
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal_with
-from werkzeug.exceptions import Forbidden
 
 import services
 from configs import dify_config

@@ -59,9 +58,6 @@ class FileApi(Resource):
         if not file.filename:
             raise FilenameNotExistsError
 
-        if source == "datasets" and not current_user.is_dataset_editor:
-            raise Forbidden()
-
         if source not in ("datasets", None):
             source = None
@@ -368,7 +368,6 @@ class ToolWorkflowProviderCreateApi(Resource):
             description=args["description"],
             parameters=args["parameters"],
             privacy_policy=args["privacy_policy"],
             labels=args["labels"],
         )
@@ -3,7 +3,7 @@ import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Literal, Optional, Union, overload
+from typing import Any, Optional, Union
 
 from flask import Flask, current_app
 from pydantic import ValidationError

@@ -36,29 +36,6 @@ logger = logging.getLogger(__name__)
 class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     _dialogue_count: int
 
-    @overload
-    def generate(
-        self,
-        app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        args: Mapping[str, Any],
-        invoke_from: InvokeFrom,
-        streaming: Literal[True],
-    ) -> Generator[str, None, None]: ...
-
-    @overload
-    def generate(
-        self,
-        app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        args: Mapping[str, Any],
-        invoke_from: InvokeFrom,
-        streaming: Literal[False],
-    ) -> Mapping[str, Any]: ...
-
-    @overload
     def generate(
         self,
         app_model: App,

@@ -67,17 +44,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
-        args: Mapping[str, Any],
-        invoke_from: InvokeFrom,
-        streaming: bool = True,
-    ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
-
-    def generate(
-        self,
-        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool = True,
-    ):
+    ) -> Mapping[str, Any] | Generator[str, None, None]:
         """
         Generate App response.
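What the removed @overload stubs bought: type checkers could narrow the return type from the literal value of streaming. A self-contained sketch of the pattern (not the dify signatures):

from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload

@overload
def generate(streaming: Literal[True]) -> Generator[str, None, None]: ...
@overload
def generate(streaming: Literal[False]) -> Mapping[str, Any]: ...
def generate(streaming: bool = True) -> Union[Mapping[str, Any], Generator[str, None, None]]:
    def stream() -> Generator[str, None, None]:
        yield "chunk"
    # callers passing a literal True/False get the narrowed return type
    return stream() if streaming else {}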
@@ -19,10 +19,8 @@ from core.app.entities.queue_entities import (
     QueueIterationNextEvent,
     QueueIterationStartEvent,
     QueueMessageReplaceEvent,
-    QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
-    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,

@@ -33,7 +31,6 @@ from core.app.entities.queue_entities import (
     QueueStopEvent,
     QueueTextChunkEvent,
     QueueWorkflowFailedEvent,
-    QueueWorkflowPartialSuccessEvent,
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )

@@ -130,6 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         self._conversation_name_generate_thread = None
         self._recorded_files: list[Mapping[str, Any]] = []
+        self.total_tokens: int = 0
 
     def process(self):
         """

@@ -320,7 +318,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
             if response:
                 yield response
-        elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
+        elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
             workflow_node_execution = self._handle_workflow_node_execution_failed(event)
 
             response = self._workflow_node_finish_to_stream_response(

@@ -329,22 +327,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 workflow_node_execution=workflow_node_execution,
             )
 
             if response:
                 yield response
-        elif isinstance(
-            event,
-            QueueNodeRetryEvent,
-        ):
-            workflow_node_execution = self._handle_workflow_node_execution_retried(
-                workflow_run=workflow_run, event=event
-            )
-
-            response = self._workflow_node_retry_to_stream_response(
-                event=event,
-                task_id=self._application_generate_entity.task_id,
-                workflow_node_execution=workflow_node_execution,
-            )
-
-            if response:
-                yield response
         elif isinstance(event, QueueParallelBranchRunStartedEvent):

@@ -379,6 +361,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             if not workflow_run:
                 raise Exception("Workflow run not initialized.")
 
+            # FIXME for issue #11221 quick fix maybe have a better solution
+            self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
             yield self._workflow_iteration_completed_to_stream_response(
                 task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
             )

@@ -392,7 +376,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             workflow_run = self._handle_workflow_run_success(
                 workflow_run=workflow_run,
                 start_at=graph_runtime_state.start_at,
-                total_tokens=graph_runtime_state.total_tokens,
+                total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
                 total_steps=graph_runtime_state.node_run_steps,
                 outputs=event.outputs,
                 conversation_id=self._conversation.id,

@@ -403,29 +387,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
             )
 
             self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
-        elif isinstance(event, QueueWorkflowPartialSuccessEvent):
-            if not workflow_run:
-                raise Exception("Workflow run not initialized.")
-
-            if not graph_runtime_state:
-                raise Exception("Graph runtime state not initialized.")
-
-            workflow_run = self._handle_workflow_run_partial_success(
-                workflow_run=workflow_run,
-                start_at=graph_runtime_state.start_at,
-                total_tokens=graph_runtime_state.total_tokens,
-                total_steps=graph_runtime_state.node_run_steps,
-                outputs=event.outputs,
-                exceptions_count=event.exceptions_count,
-                conversation_id=None,
-                trace_manager=trace_manager,
-            )
-
-            yield self._workflow_finish_to_stream_response(
-                task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
-            )
-
-            self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
         elif isinstance(event, QueueWorkflowFailedEvent):
             if not workflow_run:

@@ -443,7 +404,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 error=event.error,
                 conversation_id=self._conversation.id,
                 trace_manager=trace_manager,
-                exceptions_count=event.exceptions_count,
             )
 
             yield self._workflow_finish_to_stream_response(
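The `graph_runtime_state.total_tokens or self.total_tokens` fallback above relies on 0 being falsy in Python, which is what makes the issue-#11221 quick fix work: when the graph runtime reports zero tokens, the pipeline's own accumulated counter is used instead.

graph_total = 0          # runtime failed to accumulate iteration tokens
pipeline_total = 1234    # counter filled in by the quick fix above
print(graph_total or pipeline_total)  # 1234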
@@ -2,7 +2,7 @@ import logging
 import threading
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Literal, Union, overload
+from typing import Any, Union
 
 from flask import Flask, current_app
 from pydantic import ValidationError

@@ -28,39 +28,6 @@ logger = logging.getLogger(__name__)
 
 
 class AgentChatAppGenerator(MessageBasedAppGenerator):
-    @overload
-    def generate(
-        self,
-        *,
-        app_model: App,
-        user: Union[Account, EndUser],
-        args: Mapping[str, Any],
-        invoke_from: InvokeFrom,
-        streaming: Literal[True],
-    ) -> Generator[str, None, None]: ...
-
-    @overload
-    def generate(
-        self,
-        *,
-        app_model: App,
-        user: Union[Account, EndUser],
-        args: Mapping[str, Any],
-        invoke_from: InvokeFrom,
-        streaming: Literal[False],
-    ) -> Mapping[str, Any]: ...
-
-    @overload
-    def generate(
-        self,
-        *,
-        app_model: App,
-        user: Union[Account, EndUser],
-        args: Mapping[str, Any],
-        invoke_from: InvokeFrom,
-        streaming: bool,
-    ) -> Mapping[str, Any] | Generator[str, None, None]: ...
-
     def generate(
         self,
         *,

@@ -69,7 +36,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool = True,
-    ):
+    ) -> Mapping[str, Any] | Generator[str, None, None]:
         """
         Generate App response.
@@ -82,7 +82,7 @@ class AppGenerateResponseConverter(ABC):
for resource in metadata["retriever_resources"]:
updated_resources.append(
{
"segment_id": resource.get("segment_id", ""),
"segment_id": resource["segment_id"],
"position": resource["position"],
"document_name": resource["document_name"],
"score": resource["score"],

@@ -1,7 +1,7 @@
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import Any, Literal, Union, overload

from flask import Flask, current_app
@@ -34,9 +34,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Any,
invoke_from: InvokeFrom,
streaming: Literal[True],
stream: Literal[True] = True,
) -> Generator[str, None, None]: ...

@overload
@@ -44,29 +44,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Any,
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...

@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
stream: Literal[False] = False,
) -> dict: ...

def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Any,
invoke_from: InvokeFrom,
streaming: bool = True,
):
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.


@@ -1,7 +1,7 @@
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import Any, Literal, Union, overload

from flask import Flask, current_app
@@ -34,9 +34,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: dict,
invoke_from: InvokeFrom,
streaming: Literal[True],
stream: Literal[True] = True,
) -> Generator[str, None, None]: ...

@overload
@@ -44,29 +44,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: dict,
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...

@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
stream: Literal[False] = False,
) -> dict: ...

def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.


@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Optional, Union, overload
from typing import Any, Optional, Union

from flask import Flask, current_app
from pydantic import ValidationError
@@ -30,35 +30,6 @@ logger = logging.getLogger(__name__)


class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...

@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any]: ...

@overload
def generate(
self,
*,
@@ -70,20 +41,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]: ...

def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
) -> Mapping[str, Any] | Generator[str, None, None]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []

# parse files

@@ -6,7 +6,6 @@ from core.app.entities.queue_entities import (
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
@@ -35,8 +34,7 @@ class WorkflowAppQueueManager(AppQueueManager):
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
| QueueWorkflowFailedEvent,
):
self.stop_listen()


@@ -15,10 +15,8 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -28,7 +26,6 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -109,6 +106,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self.total_tokens: int = 0

def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -260,44 +258,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

node_start_response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)

node_success_response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if node_success_response:
yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

node_failed_response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_failed_response:
yield node_failed_response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)

response = self._workflow_node_retry_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -305,7 +266,28 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)

response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
@@ -338,6 +320,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not workflow_run:
raise Exception("Workflow run not initialized.")

# FIXME for issue #11221 quick fix maybe have a better solution
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
@@ -351,7 +335,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
@@ -361,30 +345,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)

# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
@@ -394,6 +354,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@@ -405,7 +366,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)

# save workflow app log

@@ -8,10 +8,8 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -20,7 +18,6 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -28,7 +25,6 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
@@ -36,10 +32,8 @@ from core.workflow.graph_engine.entities.event import (
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@@ -182,12 +176,8 @@ class WorkflowBasedAppRunner(AppRunner):
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self._publish_event(
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
@@ -263,36 +253,6 @@ class WorkflowBasedAppRunner(AppRunner):
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
@@ -422,36 +382,6 @@ class WorkflowBasedAppRunner(AppRunner):
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
elif isinstance(event, NodeRunRetryEvent):
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.error,
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
start_index=event.start_index,
)
)

def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""

@@ -2,7 +2,7 @@ from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional

from pydantic import BaseModel
from pydantic import BaseModel, field_validator

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import NodeRunMetadataKey
@@ -25,14 +25,12 @@ class QueueEvent(StrEnum):
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
NODE_STARTED = "node_started"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
NODE_EXCEPTION = "node_exception"
RETRIEVER_RESOURCES = "retriever_resources"
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
@@ -43,7 +41,6 @@ class QueueEvent(StrEnum):
ERROR = "error"
PING = "ping"
STOP = "stop"
RETRY = "retry"


class AppQueueEvent(BaseModel):
@@ -116,6 +113,18 @@ class QueueIterationNextEvent(AppQueueEvent):
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None

@field_validator("output", mode="before")
@classmethod
def set_output(cls, v):
"""
Set output
"""
if v is None:
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError("output must be a valid type")


class QueueIterationCompletedEvent(AppQueueEvent):
"""
@@ -240,17 +249,6 @@ class QueueWorkflowFailedEvent(AppQueueEvent):

event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
exceptions_count: int


class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
"""
QueueWorkflowFailedEvent entity
"""

event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: Optional[dict[str, Any]] = None


class QueueNodeStartedEvent(AppQueueEvent):
@@ -314,37 +312,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None


class QueueNodeRetryEvent(AppQueueEvent):
"""QueueNodeRetryEvent entity"""

event: QueueEvent = QueueEvent.RETRY

node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime

inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

error: str
retry_index: int # retry index
start_index: int # start index


class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
@@ -376,37 +343,6 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
error: str


class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
"""

event: QueueEvent = QueueEvent.NODE_EXCEPTION

node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime

inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

error: str


class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity

@@ -52,7 +52,6 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
@@ -214,7 +213,6 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []

event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
@@ -343,75 +341,6 @@ class NodeFinishStreamResponse(StreamResponse):
}


class NodeRetryStreamResponse(StreamResponse):
"""
NodeFinishStreamResponse entity
"""

class Data(BaseModel):
"""
Data entity
"""

id: str
node_id: str
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
retry_index: int = 0

event: StreamEvent = StreamEvent.NODE_RETRY
workflow_run_id: str
data: Data

def to_ignore_detail_dict(self):
return {
"event": self.event.value,
"task_id": self.task_id,
"workflow_run_id": self.workflow_run_id,
"data": {
"id": self.data.id,
"node_id": self.data.node_id,
"node_type": self.data.node_type,
"title": self.data.title,
"index": self.data.index,
"predecessor_node_id": self.data.predecessor_node_id,
"inputs": None,
"process_data": None,
"outputs": None,
"status": self.data.status,
"error": None,
"elapsed_time": self.data.elapsed_time,
"execution_metadata": None,
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"retry_index": self.data.retry_index,
},
}


class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity

@@ -110,7 +110,7 @@ class RateLimitGenerator:
raise StopIteration
try:
return next(self.generator)
except Exception:
except StopIteration:
self.close()
raise


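Note on the RateLimitGenerator hunk above: catching only StopIteration in __next__ releases the rate-limit slot exactly once, on normal exhaustion, whereas a bare except Exception would also close the limiter on transient generator errors that callers may want to observe. A minimal standalone sketch of the pattern, with hypothetical names (not Dify's actual class):

# Illustrative sketch only; names here are hypothetical.
from collections.abc import Iterator

class ClosingIterator:
    """Wrap an iterator and run a cleanup hook once it is exhausted."""

    def __init__(self, inner: Iterator, on_close):
        self.inner = inner
        self.on_close = on_close

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.inner)
        except StopIteration:
            self.on_close()  # release the slot on normal exhaustion only
            raise            # other exceptions propagate untouched
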
@@ -12,10 +12,8 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -27,7 +25,6 @@ from core.app.entities.task_entities import (
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
@@ -167,55 +164,6 @@ class WorkflowCycleManage:

return workflow_run

def _handle_workflow_run_partial_success(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)

outputs = WorkflowEntry.handle_special_values(outputs)

workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
db.session.refresh(workflow_run)

if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)

db.session.close()

return workflow_run

def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
@@ -226,7 +174,6 @@ class WorkflowCycleManage:
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
) -> WorkflowRun:
"""
Workflow run failed
@@ -246,7 +193,7 @@ class WorkflowCycleManage:
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count

db.session.commit()

running_workflow_node_executions = (
@@ -273,9 +220,9 @@ class WorkflowCycleManage:

db.session.close()

# with Session(db.engine, expire_on_commit=False) as session:
# session.add(workflow_run)
# session.refresh(workflow_run)
with Session(db.engine, expire_on_commit=False) as session:
session.add(workflow_run)
session.refresh(workflow_run)

if trace_manager:
trace_manager.add_trace_task(
@@ -371,7 +318,7 @@ class WorkflowCycleManage:
return workflow_node_execution

def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
@@ -390,11 +337,7 @@ class WorkflowCycleManage:
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
),
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
@@ -408,11 +351,8 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
)

workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
@@ -425,52 +365,6 @@ class WorkflowCycleManage:

return workflow_node_execution

def _handle_workflow_node_execution_retried(
self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
:return:
"""
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)

workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = created_at
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.index = event.start_index

db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)

return workflow_node_execution

#################################################
# to stream responses #
#################################################
@@ -539,7 +433,6 @@ class WorkflowCycleManage:
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
exceptions_count=workflow_run.exceptions_count,
),
)

@@ -590,10 +483,7 @@ class WorkflowCycleManage:

def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeExceptionEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
@@ -635,51 +525,6 @@ class WorkflowCycleManage:
),
)

def _workflow_node_retry_to_stream_response(
self,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None

return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
),
)

def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:

@@ -42,31 +42,39 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")
match f.type:
case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)

params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

prompt_class_map = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}

try:
return prompt_class_map[f.type](**params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")
return ImagePromptMessageContent(data=data, detail=image_detail_config)
case FileType.AUDIO:
encoded_string = _get_encoded_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
case _:
raise ValueError(f"file type {f.type} is not supported")


def download(f: File, /):
@@ -120,6 +128,11 @@ def _get_encoded_string(f: File, /):
return encoded_string


def _to_base64_data_string(f: File, /):
encoded_string = _get_encoded_string(f)
return f"data:{f.mime_type};base64,{encoded_string}"


def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
@@ -128,7 +141,7 @@ def _to_url(f: File, /):
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
if f.related_id is None:
raise ValueError("Missing file related_id")
return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
return helpers.get_signed_file_url(upload_file_id=f.related_id)
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
# add sign url
if f.related_id is None or f.extension is None:

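The _to_base64_data_string helper above packs a file into an RFC 2397 data URL, data:<mime_type>;base64,<payload>, which is also what the non-URL branches of the IMAGE and VIDEO cases produce. A self-contained round-trip illustrating the format (the values here are made up for the example):

# Illustrative round-trip for the data-URL format; standalone example.
import base64

payload = base64.b64encode(b"hello").decode()
data_url = f"data:text/plain;base64,{payload}"

# Split back into the mime type and the raw bytes.
header, b64_part = data_url.split(";base64,", 1)
mime_type = header.removeprefix("data:")
assert mime_type == "text/plain"
assert base64.b64decode(b64_part) == b"hello"
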
@@ -24,12 +24,6 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]


class MaxRetriesExceededError(Exception):
"""Raised when the maximum number of retries is exceeded."""

pass


def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@@ -45,6 +39,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)

retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
@@ -69,7 +64,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))

raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")


def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

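The retry loop above sleeps BACKOFF_FACTOR * (2 ** (retries - 1)) between attempts, so with the BACKOFF_FACTOR = 0.5 shown in the hunk the delays grow 0.5 s, 1 s, 2 s, and so on, until max_retries is exhausted. A standalone sketch of the same schedule (the wrapper name and constants mirror the hunk but are illustrative, not the actual ssrf_proxy module):

# Illustrative exponential-backoff sketch; hypothetical wrapper, same schedule.
import time

BACKOFF_FACTOR = 0.5
MAX_RETRIES = 3

def call_with_backoff(func):
    retries = 0
    while True:
        try:
            return func()
        except Exception:
            retries += 1
            if retries > MAX_RETRIES:
                raise
            time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))  # 0.5s, 1s, 2s, ...
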
@@ -1,9 +1,9 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Optional
from typing import Literal, Optional

from pydantic import BaseModel, Field, computed_field, field_validator
from pydantic import BaseModel, Field, field_validator


class PromptMessageRole(Enum):
@@ -67,6 +67,7 @@ class PromptMessageContent(BaseModel):
"""

type: PromptMessageContentType
data: str


class TextPromptMessageContent(PromptMessageContent):
@@ -75,35 +76,21 @@ class TextPromptMessageContent(PromptMessageContent):
"""

type: PromptMessageContentType = PromptMessageContentType.TEXT
data: str


class MultiModalPromptMessageContent(PromptMessageContent):
"""
Model class for multi-modal prompt message content.
"""

type: PromptMessageContentType
format: str = Field(..., description="the format of multi-modal file")
base64_data: str = Field("", description="the base64 data of multi-modal file")
url: str = Field("", description="the url of multi-modal file")
mime_type: str = Field(..., description="the mime type of multi-modal file")

@computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"


class VideoPromptMessageContent(MultiModalPromptMessageContent):
class VideoPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")


class AudioPromptMessageContent(MultiModalPromptMessageContent):
class AudioPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")
format: str = Field(..., description="Audio format")


class ImagePromptMessageContent(MultiModalPromptMessageContent):
class ImagePromptMessageContent(PromptMessageContent):
"""
Model class for image prompt message content.
"""
@@ -116,8 +103,11 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
detail: DETAIL = DETAIL.LOW


class DocumentPromptMessageContent(MultiModalPromptMessageContent):
class DocumentPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
mime_type: str
data: str


class PromptMessage(ABC, BaseModel):

@@ -1,4 +1,5 @@
import base64
import io
import json
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
@@ -17,6 +18,7 @@ from anthropic.types import (
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities import (
@@ -496,19 +498,22 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.base64_data:
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.url).content
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}"
)
else:
base64_data = message_content.base64_data
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]

mime_type = message_content.mime_type
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
@@ -529,7 +534,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
sub_message_dict = {
"type": "document",
"source": {
"type": "base64",
"type": message_content.encode_format,
"media_type": message_content.mime_type,
"data": message_content.data,
},

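In the image branch above, a plain URL is fetched and its mime type derived from the decoded image, while a data URL is split on ";base64,". A standalone sketch of that data-URL split (the input value is made up for the example):

# Illustrative parse of a "data:<mime>;base64,<payload>" string, as in the hunk above.
data = "data:image/png;base64,iVBORw0KGgo="
data_split = data.split(";base64,")
mime_type = data_split[0].replace("data:", "")  # "image/png"
base64_data = data_split[1]                     # raw base64 payload
assert mime_type in {"image/jpeg", "image/png", "image/gif", "image/webp"}
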
@@ -819,82 +819,6 @@ LLM_BASE_MODELS = [
),
),
),
AzureBaseModel(
base_model_name="gpt-4o-2024-11-20",
entity=AIModelEntity(
model="fake-deployment-name",
label=I18nObject(
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.VISION,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=16384),
ParameterRule(
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
max=1,
),
ParameterRule(
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
type="string",
help=I18nObject(
zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
),
required=False,
options=["text", "json_object", "json_schema"],
),
ParameterRule(
name="json_schema",
label=I18nObject(en_US="JSON Schema"),
type="text",
help=I18nObject(
zh_Hans="设置返回的json schema,llm将按照它返回",
en_US="Set a response json schema will ensure LLM to adhere it.",
),
required=False,
),
],
pricing=PriceConfig(
input=5.00,
output=15.00,
unit=0.000001,
currency="USD",
),
),
),
AzureBaseModel(
base_model_name="gpt-4-turbo",
entity=AIModelEntity(

@@ -86,9 +86,6 @@ model_credential_schema:
- label:
en_US: '2024-06-01'
value: '2024-06-01'
- label:
en_US: '2024-10-21'
value: '2024-10-21'
placeholder:
zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here
@@ -171,12 +168,6 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-11-20
value: gpt-4o-2024-11-20
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo
value: gpt-4-turbo

@@ -92,10 +92,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
embeddings[i] = (average / np.linalg.norm(average)).tolist()

# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

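The embedding hunk above contrasts two ways of finishing a weighted-average embedding: one side L2-normalizes and then guards against NaN (a zero-norm average divides zero by zero), the other normalizes directly. A small self-contained illustration of the guarded variant, with made-up input values:

# Illustrative L2 normalization with the NaN guard shown in the hunk above.
import numpy as np

vectors = np.array([[3.0, 4.0], [1.0, 0.0]])
weights = [2, 1]  # e.g. token counts per batch

average = np.average(vectors, axis=0, weights=weights)
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():  # the norm of a zero vector yields 0/0 -> nan
    raise ValueError("Normalized embedding is nan please try again")
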
@@ -10,7 +10,6 @@ from core.model_runtime.entities.llm_entities import (
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
@@ -106,11 +105,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_dict = {"role": "user", "content": message_content.data}
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
raise ValueError("Content object type not support image_url")
|
||||
raise ValueError("User message content must be str")
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
|
||||
|
||||
def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
|
||||
region_name = credentials.get("aws_region")
|
||||
if not region_name:
|
||||
raise InvokeBadRequestError("aws_region is required")
|
||||
client_config = Config(region_name=region_name)
|
||||
aws_access_key_id = credentials.get("aws_access_key_id")
|
||||
aws_secret_access_key = credentials.get("aws_secret_access_key")
|
||||
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
# use aksk to call bedrock
|
||||
client = boto3.client(
|
||||
service_name=service_name,
|
||||
config=client_config,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
else:
|
||||
# use iam without aksk to call
|
||||
client = boto3.client(service_name=service_name, config=client_config)
|
||||
|
||||
return client
|
||||
@@ -40,7 +40,6 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||
@@ -174,7 +173,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param stream: is stream response
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
bedrock_client = get_bedrock_client("bedrock-runtime", credentials)
|
||||
bedrock_client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
||||
region_name=credentials["aws_region"],
|
||||
)
|
||||
|
||||
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
|
||||
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
from typing import Optional

import boto3
from botocore.config import Config

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,

@@ -11,7 +14,6 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client


class BedrockRerankModel(RerankModel):

@@ -46,7 +48,13 @@ class BedrockRerankModel(RerankModel):
            return RerankResult(model=model, docs=docs)

        # initialize client
        bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
        client_config = Config(region_name=credentials["aws_region"])
        bedrock_runtime = boto3.client(
            service_name="bedrock-agent-runtime",
            config=client_config,
            aws_access_key_id=credentials.get("aws_access_key_id", ""),
            aws_secret_access_key=credentials.get("aws_secret_access_key"),
        )
        queries = [{"type": "TEXT", "textQuery": {"text": query}}]
        text_sources = []
        for text in docs:
@@ -62,10 +70,7 @@ class BedrockRerankModel(RerankModel):
                }
            )
        modelId = model
        region = credentials.get("aws_region")
        # region is a required field
        if not region:
            raise InvokeBadRequestError("aws_region is required in credentials")
        region = credentials["aws_region"]
        model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
        rerankingConfiguration = {
            "type": "BEDROCK_RERANKING_MODEL",

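Both sides of this hunk end up building the same foundation-model ARN from the region and model id; they differ only in whether the region is validated first. A one-line sketch with hypothetical values:

    region, modelId = "us-west-2", "cohere.rerank-v3-5:0"  # hypothetical values
    model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
    # -> arn:aws:bedrock:us-west-2::foundation-model/cohere.rerank-v3-5:0
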
@@ -3,6 +3,8 @@ import logging
import time
from typing import Optional

import boto3
from botocore.config import Config
from botocore.exceptions import (
    ClientError,
    EndpointConnectionError,

@@ -23,7 +25,6 @@ from core.model_runtime.errors.invoke import (
    InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client

logger = logging.getLogger(__name__)

@@ -47,7 +48,14 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
        :param input_type: input type
        :return: embeddings result
        """
        bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)
        client_config = Config(region_name=credentials["aws_region"])

        bedrock_runtime = boto3.client(
            service_name="bedrock-runtime",
            config=client_config,
            aws_access_key_id=credentials.get("aws_access_key_id"),
            aws_secret_access_key=credentials.get("aws_secret_access_key"),
        )

        embeddings = []
        token_usage = 0

@@ -88,10 +88,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
                average = embeddings_batch[0]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embedding = (average / np.linalg.norm(average)).tolist()
            if np.isnan(embedding).any():
                raise ValueError("Normalized embedding is nan please try again")
            embeddings[i] = embedding
            embeddings[i] = (average / np.linalg.norm(average)).tolist()

        # calc usage
        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

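The guard dropped in this hunk protected against a zero-vector average: dividing by np.linalg.norm(average) yields NaN when the norm is zero, which would silently poison the stored embedding. A standalone sketch of the normalize-and-check step, with hypothetical batch data:

    import numpy as np

    _result = np.array([[0.3, 0.4], [0.1, 0.2]])  # hypothetical chunk embeddings for one input
    weights = [120, 80]                           # hypothetical token counts used as weights

    average = np.average(_result, axis=0, weights=weights)
    embedding = (average / np.linalg.norm(average)).tolist()
    if np.isnan(embedding).any():
        raise ValueError("Normalized embedding is nan please try again")
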
@@ -24,9 +24,6 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        # {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}}
        if "response_format" in model_parameters:
            model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:

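The three lines removed above translated the OpenAI-style string form of response_format into the object form the upstream API expects. A minimal self-contained sketch of that conversion (the parameter values are hypothetical):

    model_parameters = {"temperature": 0.3, "response_format": "json_object"}  # hypothetical input

    # {"response_format": "xx"} must become {"response_format": {"type": "xx"}}
    if "response_format" in model_parameters:
        model_parameters["response_format"] = {"type": model_parameters["response_format"]}

    assert model_parameters["response_format"] == {"type": "json_object"}
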
@@ -1,93 +0,0 @@
model: InternVL2-8B
label:
  en_US: InternVL2-8B
model_type: llm
features:
  - vision
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_k
    use_template: top_k
    label:
      en_US: "Top K"
      zh_Hans: "Top K"
    type: int
    default: 50
    min: 0
    max: 100
    required: true
    help:
      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"

@@ -1,93 +0,0 @@
model: InternVL2.5-26B
label:
  en_US: InternVL2.5-26B
model_type: llm
features:
  - vision
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    label:
      en_US: "Max Tokens"
      zh_Hans: "最大Token数"
    type: int
    default: 512
    min: 1
    required: true
    help:
      en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
      zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"

  - name: temperature
    use_template: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    use_template: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_k
    use_template: top_k
    label:
      en_US: "Top K"
      zh_Hans: "Top K"
    type: int
    default: 50
    min: 0
    max: 100
    required: true
    help:
      en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
      zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: -1.0
    max: 1.0
    precision: 1
    required: false
    help:
      en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
      zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"

@@ -6,5 +6,3 @@
- deepseek-coder-33B-instruct-chat
- deepseek-coder-33B-instruct-completions
- codegeex4-all-9b
- InternVL2.5-26B
- InternVL2-8B

@@ -29,26 +29,18 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials, model, model_parameters)
        return super()._invoke(
            GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model),
            credentials,
            prompt_messages,
            model_parameters,
            tools,
            stop,
            stream,
            user,
        )
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials, model, None)
        super().validate_credentials(GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model), credentials)
        self._add_custom_parameters(credentials, None)
        super().validate_credentials(model, credentials)

    def _add_custom_parameters(self, credentials: dict, model: Optional[str], model_parameters: dict) -> None:
    def _add_custom_parameters(self, credentials: dict, model: Optional[str]) -> None:
        if model is None:
            model = "Qwen2-72B-Instruct"

        credentials["endpoint_url"] = "https://ai.gitee.com/v1"
        model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model)
        credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/"
        if model.endswith("completions"):
            credentials["mode"] = LLMMode.COMPLETION.value
        else:

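The two endpoint styles in this hunk differ in routing: one resolves each model to a per-model serverless URL via MODEL_TO_IDENTITY, the other sends everything to the unified OpenAI-compatible gateway at https://ai.gitee.com/v1. A minimal sketch of the per-model construction (mapping contents hypothetical):

    MODEL_TO_IDENTITY = {"codegeex4-all-9b": "codegeex4-all-9b"}  # hypothetical excerpt

    model = "Qwen2-72B-Instruct"
    model_identity = MODEL_TO_IDENTITY.get(model, model)
    endpoint_url = f"https://ai.gitee.com/api/serverless/{model_identity}/"
    # -> https://ai.gitee.com/api/serverless/Qwen2-72B-Instruct/
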
@@ -1,5 +1,3 @@
- gemini-2.0-flash-exp
- gemini-2.0-flash-thinking-exp-1219
- gemini-1.5-pro
- gemini-1.5-pro-latest
- gemini-1.5-pro-001

@@ -13,8 +11,6 @@
- gemini-1.5-flash-exp-0827
- gemini-1.5-flash-8b-exp-0827
- gemini-1.5-flash-8b-exp-0924
- gemini-exp-1206
- gemini-exp-1121
- gemini-exp-1114
- gemini-pro
- gemini-pro-vision

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152

@@ -1,41 +0,0 @@
model: gemini-2.0-flash-exp
label:
  en_US: Gemini 2.0 Flash Exp
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -1,39 +0,0 @@
model: gemini-2.0-flash-thinking-exp-1219
label:
  en_US: Gemini 2.0 Flash Thinking Exp 1219
model_type: llm
features:
  - agent-thought
  - vision
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -8,8 +8,6 @@ features:
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767

@@ -7,9 +7,6 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767

@@ -1,41 +0,0 @@
model: gemini-exp-1206
label:
  en_US: Gemini exp 1206
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -7,9 +7,6 @@ features:
  - vision
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767

@@ -1,23 +1,24 @@
import base64
import io
import json
import os
import tempfile
import time
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast

import google.ai.generativelanguage as glm
import google.generativeai as genai
import requests
from google.api_core import exceptions
from google.generativeai.types import ContentType, File, GenerateContentResponse
from google.generativeai.client import _ClientManager
from google.generativeai.types import ContentType, GenerateContentResponse
from google.generativeai.types.content_types import to_part
from PIL import Image

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
@@ -34,7 +35,21 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_redis import redis_client

GOOGLE_AVAILABLE_MIMETYPE = [
    "application/pdf",
    "application/x-javascript",
    "text/javascript",
    "application/x-python",
    "text/x-python",
    "text/plain",
    "text/html",
    "text/css",
    "text/md",
    "text/csv",
    "text/xml",
    "text/rtf",
]


class GoogleLargeLanguageModel(LargeLanguageModel):
@@ -186,17 +201,29 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        if stop:
            config_kwargs["stop_sequences"] = stop

        genai.configure(api_key=credentials["google_api_key"])
        google_model = genai.GenerativeModel(model_name=model)

        history = []

        for msg in prompt_messages:  # makes message roles strictly alternating
            content = self._format_message_to_glm_content(msg)
            if history and history[-1]["role"] == content["role"]:
                history[-1]["parts"].extend(content["parts"])
            else:
                history.append(content)
        # hack for gemini-pro-vision, which currently does not support multi-turn chat
        if model == "gemini-pro-vision":
            last_msg = prompt_messages[-1]
            content = self._format_message_to_glm_content(last_msg)
            history.append(content)
        else:
            for msg in prompt_messages:  # makes message roles strictly alternating
                content = self._format_message_to_glm_content(msg)
                if history and history[-1]["role"] == content["role"]:
                    history[-1]["parts"].extend(content["parts"])
                else:
                    history.append(content)

        # Create a new ClientManager with tenant's API key
        new_client_manager = _ClientManager()
        new_client_manager.configure(api_key=credentials["google_api_key"])
        new_custom_client = new_client_manager.make_client("generative")

        google_model._client = new_custom_client

        response = google_model.generate_content(
            contents=history,
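
The merging loop in this hunk exists because the Gemini chat API requires strictly alternating user/model turns; consecutive same-role messages are folded into the previous turn's parts. A self-contained sketch with hypothetical message dicts:

    messages = [
        {"role": "user", "parts": ["Hello"]},
        {"role": "user", "parts": ["One more thing..."]},  # same role as the previous turn
        {"role": "model", "parts": ["Hi!"]},
    ]

    history = []
    for content in messages:
        if history and history[-1]["role"] == content["role"]:
            history[-1]["parts"].extend(content["parts"])  # fold into the previous turn
        else:
            history.append(content)

    # history now alternates: one user turn carrying both parts, then the model turn.
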
@@ -290,12 +317,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            )
        else:
            # calculate num tokens
            if hasattr(response, "usage_metadata") and response.usage_metadata:
                prompt_tokens = response.usage_metadata.prompt_token_count
                completion_tokens = response.usage_metadata.candidates_token_count
            else:
                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

            # transform usage
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -323,7 +346,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):

        content = message.content
        if isinstance(content, list):
            content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
            content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)

        if isinstance(message, UserPromptMessage):
            message_text = f"{human_prompt} {content}"
@@ -336,40 +359,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):

        return message_text

    def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
        key = f"{message_content.type.value}:{hash(message_content.data)}"
        if redis_client.exists(key):
            try:
                return genai.get_file(redis_client.get(key).decode())
            except:
                pass
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            if message_content.base64_data:
                file_content = base64.b64decode(message_content.base64_data)
                temp_file.write(file_content)
            else:
                try:
                    response = requests.get(message_content.url)
                    response.raise_for_status()
                    temp_file.write(response.content)
                except Exception as ex:
                    raise ValueError(f"Failed to fetch data from url {message_content.url}, {ex}")
            temp_file.flush()

            file = genai.upload_file(path=temp_file.name, mime_type=message_content.mime_type)
            while file.state.name == "PROCESSING":
                time.sleep(5)
                file = genai.get_file(file.name)
            # google will delete your upload files in 2 days.
            redis_client.setex(key, 47 * 60 * 60, file.name)

            try:
                os.unlink(temp_file.name)
            except PermissionError:
                # windows may raise permission error
                pass
            return file

    def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
        """
        Format a single message into glm.Content for Google API
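
The deleted _upload_file_content_to_google combined Google's file API with a Redis cache: handles were keyed by a hash of the content and cached for 47 hours, just under the roughly two days Google keeps uploads. A compressed sketch of that skeleton, assuming the same genai, redis_client, and time objects as the hunk above (the function name and arguments are a hypothetical condensation):

    def cached_google_upload(cache_key: str, path: str, mime_type: str):
        # Reuse a previously uploaded file handle if it is still cached.
        if redis_client.exists(cache_key):
            try:
                return genai.get_file(redis_client.get(cache_key).decode())
            except Exception:
                pass  # stale cache entry; fall through and re-upload
        file = genai.upload_file(path=path, mime_type=mime_type)
        while file.state.name == "PROCESSING":  # poll until Google finishes processing
            time.sleep(5)
            file = genai.get_file(file.name)
        redis_client.setex(cache_key, 47 * 60 * 60, file.name)  # expire before Google deletes it
        return file
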
@@ -385,8 +374,28 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            for c in message.content:
                if c.type == PromptMessageContentType.TEXT:
                    glm_content["parts"].append(to_part(c.data))
                else:
                    glm_content["parts"].append(self._upload_file_content_to_google(c))
                elif c.type == PromptMessageContentType.IMAGE:
                    message_content = cast(ImagePromptMessageContent, c)
                    if message_content.data.startswith("data:"):
                        metadata, base64_data = c.data.split(",", 1)
                        mime_type = metadata.split(";", 1)[0].split(":")[1]
                    else:
                        # fetch image data from url
                        try:
                            image_content = requests.get(message_content.data).content
                            with Image.open(io.BytesIO(image_content)) as img:
                                mime_type = f"image/{img.format.lower()}"
                            base64_data = base64.b64encode(image_content).decode("utf-8")
                        except Exception as ex:
                            raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
                    blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
                    glm_content["parts"].append(blob)
                elif c.type == PromptMessageContentType.DOCUMENT:
                    message_content = cast(DocumentPromptMessageContent, c)
                    if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
                        raise ValueError(f"Unsupported mime type {message_content.mime_type}")
                    blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
                    glm_content["parts"].append(blob)

            return glm_content
        elif isinstance(message, AssistantPromptMessage):

@@ -1,5 +1,4 @@
- llama-3.1-405b-reasoning
- llama-3.3-70b-versatile
- llama-3.1-70b-versatile
- llama-3.1-8b-instant
- llama3-70b-8192

@@ -1,25 +0,0 @@
model: gemma-7b-it
label:
  zh_Hans: Gemma 7B Instruction Tuned
  en_US: Gemma 7B Instruction Tuned
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 8192
pricing:
  input: '0.05'
  output: '0.1'
  unit: '0.000001'
  currency: USD

@@ -1,25 +0,0 @@
model: gemma2-9b-it
label:
  zh_Hans: Gemma 2 9B Instruction Tuned
  en_US: Gemma 2 9B Instruction Tuned
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 8192
pricing:
  input: '0.05'
  output: '0.1'
  unit: '0.000001'
  currency: USD

@@ -1,8 +1,7 @@
model: llama-3.1-70b-versatile
deprecated: true
label:
  zh_Hans: Llama-3.1-70b-versatile (DEPRECATED)
  en_US: Llama-3.1-70b-versatile (DEPRECATED)
  zh_Hans: Llama-3.1-70b-versatile
  en_US: Llama-3.1-70b-versatile
model_type: llm
features:
  - agent-thought

@@ -1,5 +1,4 @@
model: llama-3.2-11b-text-preview
deprecated: true
label:
  zh_Hans: Llama 3.2 11B Text (Preview)
  en_US: Llama 3.2 11B Text (Preview)

@@ -1,5 +1,4 @@
model: llama-3.2-90b-text-preview
deprecated: true
label:
  zh_Hans: Llama 3.2 90B Text (Preview)
  en_US: Llama 3.2 90B Text (Preview)

@@ -1,25 +0,0 @@
model: llama-3.3-70b-specdec
label:
  zh_Hans: Llama 3.3 70B Specdec
  en_US: Llama 3.3 70B Specdec
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32768
pricing:
  input: "0.05"
  output: "0.1"
  unit: "0.000001"
  currency: USD

@@ -1,25 +0,0 @@
model: llama-3.3-70b-versatile
label:
  zh_Hans: Llama 3.3 70B Versatile
  en_US: Llama 3.3 70B Versatile
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32768
pricing:
  input: "0.05"
  output: "0.1"
  unit: "0.000001"
  currency: USD

Some files were not shown because too many files have changed in this diff.