Compare commits

..

1 Commits

Author SHA1 Message Date
jyong
8d34db912f update tidb batch get endpoint to basic mode 2024-12-06 14:46:38 +08:00
545 changed files with 7440 additions and 15165 deletions

View File

@@ -7,6 +7,5 @@ echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
source /home/vscode/.bashrc
source /home/vscode/.bashrc

View File

@@ -50,9 +50,6 @@ jobs:
- name: Run ModelRuntime
run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
- name: Run dify config tests
run: poetry run -C api python dev/pytest/pytest_config_tests.py
- name: Run Tool
run: poetry run -C api bash dev/pytest/pytest_tools.sh

View File

@@ -9,6 +9,5 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"

View File

@@ -37,7 +37,6 @@ jobs:
- name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true'
run: |
poetry run -C api ruff --version
poetry run -C api ruff check ./api
poetry run -C api ruff format --check ./api

View File

@@ -51,7 +51,7 @@ jobs:
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
@@ -67,7 +67,6 @@ jobs:
pgvector
chroma
elasticsearch
tidb
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh

View File

@@ -56,27 +56,20 @@ DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=opendal
# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
OPENDAL_SCHEME=fs
OPENDAL_FS_ROOT=storage
# S3 Storage configuration
# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage
S3_USE_AWS_MANAGED_IAM=false
S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com
S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
AZURE_BLOB_CONTAINER_NAME=yout-container-name
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
# Aliyun oss Storage configuration
ALIYUN_OSS_BUCKET_NAME=your-bucket-name
ALIYUN_OSS_ACCESS_KEY=your-access-key
@@ -86,7 +79,6 @@ ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
@@ -133,8 +125,8 @@ SUPABASE_URL=your-server-url
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
VECTOR_STORE=weaviate
# Weaviate configuration
@@ -285,7 +277,6 @@ VIKINGDB_SOCKET_TIMEOUT=30
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin
LINDORM_PASSWORD=admin
USING_UGC_INDEX=False
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
@@ -304,7 +295,8 @@ UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model configuration
MULTIMODAL_SEND_FORMAT=base64
MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
@@ -389,8 +381,6 @@ LOG_FILE_BACKUP_COUNT=5
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
@@ -423,7 +413,3 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400

View File

@@ -70,6 +70,7 @@ ignore = [
"SIM113", # eumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"SIM300", # yoda-conditions,
]
[lint.per-file-ignores]

View File

@@ -1,30 +1,13 @@
from libs import version_utils
from app_factory import create_app
from libs import threadings_utils, version_utils
# preparation before creating app
version_utils.check_supported_python_version()
def is_db_command():
import sys
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
return True
return False
threadings_utils.apply_gevent_threading_patch()
# create app
if is_db_command():
from app_factory import create_migrations_app
app = create_migrations_app()
else:
from app_factory import create_app
from libs import threadings_utils
threadings_utils.apply_gevent_threading_patch()
app = create_app()
celery = app.extensions["celery"]
app = create_app()
celery = app.extensions["celery"]
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)

View File

@@ -1,4 +1,5 @@
import logging
import os
import time
from configs import dify_config
@@ -16,6 +17,15 @@ def create_flask_app_with_configs() -> DifyApp:
dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump())
# populate configs into system environment variables
for key, value in dify_app.config.items():
if isinstance(value, str):
os.environ[key] = value
elif isinstance(value, int | float | bool):
os.environ[key] = str(value)
elif value is None:
os.environ[key] = ""
return dify_app
@@ -88,14 +98,3 @@ def initialize_extensions(app: DifyApp):
end_time = time.perf_counter()
if dify_config.DEBUG:
logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
def create_migrations_app():
app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate
# Initialize only required extensions
ext_database.init_app(app)
ext_migrate.init_app(app)
return app

View File

@@ -1,51 +1,11 @@
import logging
from typing import Any
from pydantic_settings import SettingsConfigDict
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
from .extra import ExtraServiceConfig
from .feature import FeatureConfig
from .middleware import MiddlewareConfig
from .packaging import PackagingInfo
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
from .remote_settings_sources.apollo import ApolloSettingsSource
logger = logging.getLogger(__name__)
class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
def __init__(self, settings_cls: type[BaseSettings]):
super().__init__(settings_cls)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
def __call__(self) -> dict[str, Any]:
current_state = self.current_state
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
if not remote_source_name:
return {}
remote_source: RemoteSettingsSource | None = None
match remote_source_name:
case RemoteSettingsSourceName.APOLLO:
remote_source = ApolloSettingsSource(current_state)
case _:
logger.warning(f"Unsupported remote source: {remote_source_name}")
return {}
d: dict[str, Any] = {}
for field_name, field in self.settings_cls.model_fields.items():
field_value, field_key, value_is_complex = remote_source.get_field_value(field, field_name)
field_value = remote_source.prepare_field_value(field_name, field, field_value, value_is_complex)
if field_value is not None:
d[field_key] = field_value
return d
from configs.deploy import DeploymentConfig
from configs.enterprise import EnterpriseFeatureConfig
from configs.extra import ExtraServiceConfig
from configs.feature import FeatureConfig
from configs.middleware import MiddlewareConfig
from configs.packaging import PackagingInfo
class DifyConfig(
@@ -59,8 +19,6 @@ class DifyConfig(
MiddlewareConfig,
# Extra service configs
ExtraServiceConfig,
# Remote source configs
RemoteSettingsSourceConfig,
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
@@ -77,20 +35,3 @@ class DifyConfig(
# please consider to arrange it in the proper config group of existed or added
# for better readability and maintainability.
# Thanks for your concentration and consideration.
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
RemoteSettingsSourceFactory(settings_cls),
dotenv_settings,
file_secret_settings,
)

View File

@@ -439,17 +439,6 @@ class WorkflowConfig(BaseSettings):
)
class WorkflowNodeExecutionConfig(BaseSettings):
"""
Configuration for workflow node execution
"""
MAX_SUBMIT_COUNT: PositiveInt = Field(
description="Maximum number of submitted thread count in a ThreadPool for parallel node execution",
default=100,
)
class AuthConfig(BaseSettings):
"""
Configuration for authentication and OAuth
@@ -485,11 +474,6 @@ class AuthConfig(BaseSettings):
default=60,
)
LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400,
)
class ModerationConfig(BaseSettings):
"""
@@ -665,9 +649,14 @@ class IndexingConfig(BaseSettings):
)
class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending files in multimodal contexts ('base64' or 'url'), default is base64",
class VisionFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
@@ -773,20 +762,19 @@ class FeatureConfig(
FileAccessConfig,
FileUploadConfig,
HttpConfig,
VisionFormatConfig,
InnerAPIConfig,
IndexingConfig,
LoggingConfig,
MailConfig,
ModelLoadBalanceConfig,
ModerationConfig,
MultiModalTransferConfig,
PositionConfig,
RagEtlConfig,
SecurityConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,
WorkflowNodeExecutionConfig,
WorkspaceConfig,
LoginConfig,
# hosted services config

View File

@@ -1,69 +1,54 @@
from typing import Any, Literal, Optional
from typing import Any, Optional
from urllib.parse import quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from .cache.redis_config import RedisConfig
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from .storage.oci_storage_config import OCIStorageConfig
from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig
from .vdb.opensearch_config import OpenSearchConfig
from .vdb.oracle_config import OracleConfig
from .vdb.pgvector_config import PGVectorConfig
from .vdb.pgvectors_config import PGVectoRSConfig
from .vdb.qdrant_config import QdrantConfig
from .vdb.relyt_config import RelytConfig
from .vdb.tencent_vector_config import TencentVectorDBConfig
from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from .vdb.tidb_vector_config import TiDBVectorConfig
from .vdb.upstash_config import UpstashConfig
from .vdb.vikingdb_config import VikingDBConfig
from .vdb.weaviate_config import WeaviateConfig
from configs.middleware.cache.redis_config import RedisConfig
from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageConfig
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.couchbase_config import CouchbaseConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.lindorm_config import LindormConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
from configs.middleware.vdb.oracle_config import OracleConfig
from configs.middleware.vdb.pgvector_config import PGVectorConfig
from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
from configs.middleware.vdb.qdrant_config import QdrantConfig
from configs.middleware.vdb.relyt_config import RelytConfig
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.upstash_config import UpstashConfig
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings):
STORAGE_TYPE: Literal[
"opendal",
"s3",
"aliyun-oss",
"azure-blob",
"baidu-obs",
"google-storage",
"huawei-obs",
"oci-storage",
"tencent-cos",
"volcengine-tos",
"supabase",
"local",
] = Field(
STORAGE_TYPE: str = Field(
description="Type of storage to use."
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
"'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
default="opendal",
" Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', "
"'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.",
default="local",
)
STORAGE_LOCAL_PATH: str = Field(
description="Path for local storage when STORAGE_TYPE is set to 'local'.",
default="storage",
deprecated=True,
)
@@ -88,7 +73,7 @@ class KeywordStoreConfig(BaseSettings):
)
class DatabaseConfig(BaseSettings):
class DatabaseConfig:
DB_HOST: str = Field(
description="Hostname or IP address of the database server.",
default="localhost",
@@ -250,7 +235,6 @@ class MiddlewareConfig(
GoogleCloudStorageConfig,
HuaweiCloudOBSStorageConfig,
OCIStorageConfig,
OpenDALStorageConfig,
S3StorageConfig,
SupabaseStorageConfig,
TencentCloudCOSStorageConfig,

View File

@@ -1,10 +1,9 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
class BaiduOBSStorageConfig(BaseSettings):
class BaiduOBSStorageConfig(BaseModel):
"""
Configuration settings for Baidu Object Storage Service (OBS)
"""

View File

@@ -1,10 +1,9 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
class HuaweiCloudOBSStorageConfig(BaseSettings):
class HuaweiCloudOBSStorageConfig(BaseModel):
"""
Configuration settings for Huawei Cloud Object Storage Service (OBS)
"""

View File

@@ -1,9 +0,0 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class OpenDALStorageConfig(BaseSettings):
OPENDAL_SCHEME: str = Field(
default="fs",
description="OpenDAL scheme.",
)

View File

@@ -1,10 +1,9 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
class SupabaseStorageConfig(BaseSettings):
class SupabaseStorageConfig(BaseModel):
"""
Configuration settings for Supabase Object Storage Service
"""

View File

@@ -1,10 +1,9 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
class VolcengineTOSStorageConfig(BaseSettings):
class VolcengineTOSStorageConfig(BaseModel):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
"""

View File

@@ -1,10 +1,9 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field, PositiveInt
class AnalyticdbConfig(BaseSettings):
class AnalyticdbConfig(BaseModel):
"""
Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
Refer to the following documentation for details on obtaining credentials:

View File

@@ -1,10 +1,9 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
class CouchbaseConfig(BaseSettings):
class CouchbaseConfig(BaseModel):
"""
Couchbase configs
"""

View File

@@ -21,14 +21,3 @@ class LindormConfig(BaseSettings):
description="Lindorm password",
default=None,
)
DEFAULT_INDEX_TYPE: Optional[str] = Field(
description="Lindorm Vector Index Type, hnsw or flat is available in dify",
default="hnsw",
)
DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
)
USING_UGC_INDEX: Optional[bool] = Field(
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
default=False,
)

View File

@@ -1,8 +1,7 @@
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field, PositiveInt
class MyScaleConfig(BaseSettings):
class MyScaleConfig(BaseModel):
"""
Configuration settings for MyScale vector database
"""

View File

@@ -1,10 +1,9 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
class VikingDBConfig(BaseSettings):
class VikingDBConfig(BaseModel):
"""
Configuration for connecting to Volcengine VikingDB.
Refer to the following documentation for details on obtaining credentials:

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.14.1",
default="0.13.1",
)
COMMIT_SHA: str = Field(

View File

@@ -1,17 +0,0 @@
from typing import Optional
from pydantic import Field
from .apollo import ApolloSettingsSourceInfo
from .base import RemoteSettingsSource
from .enums import RemoteSettingsSourceName
class RemoteSettingsSourceConfig(ApolloSettingsSourceInfo):
REMOTE_SETTINGS_SOURCE_NAME: RemoteSettingsSourceName | str = Field(
description="name of remote config source",
default="",
)
__all__ = ["RemoteSettingsSource", "RemoteSettingsSourceConfig", "RemoteSettingsSourceName"]

View File

@@ -1,55 +0,0 @@
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import Field
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings
from configs.remote_settings_sources.base import RemoteSettingsSource
from .client import ApolloClient
class ApolloSettingsSourceInfo(BaseSettings):
"""
Packaging build information
"""
APOLLO_APP_ID: Optional[str] = Field(
description="apollo app_id",
default=None,
)
APOLLO_CLUSTER: Optional[str] = Field(
description="apollo cluster",
default=None,
)
APOLLO_CONFIG_URL: Optional[str] = Field(
description="apollo config url",
default=None,
)
APOLLO_NAMESPACE: Optional[str] = Field(
description="apollo namespace",
default=None,
)
class ApolloSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]):
self.client = ApolloClient(
app_id=configs["APOLLO_APP_ID"],
cluster=configs["APOLLO_CLUSTER"],
config_url=configs["APOLLO_CONFIG_URL"],
start_hot_update=False,
_notification_map={configs["APOLLO_NAMESPACE"]: -1},
)
self.namespace = configs["APOLLO_NAMESPACE"]
self.remote_configs = self.client.get_all_dicts(self.namespace)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
field_value = self.remote_configs.get(field_name)
return field_value, field_name, False

View File

@@ -1,303 +0,0 @@
import hashlib
import json
import logging
import os
import threading
import time
from pathlib import Path
from .python_3x import http_request, makedirs_wrapper
from .utils import (
CONFIGURATIONS,
NAMESPACE_NAME,
NOTIFICATION_ID,
get_value_from_dict,
init_ip,
no_key_cache_key,
signature,
url_encode_wrapper,
)
logger = logging.getLogger(__name__)
class ApolloClient:
def __init__(
self,
config_url,
app_id,
cluster="default",
secret="",
start_hot_update=True,
change_listener=None,
_notification_map=None,
):
# Core routing parameters
self.config_url = config_url
self.cluster = cluster
self.app_id = app_id
# Non-core parameters
self.ip = init_ip()
self.secret = secret
# Check the parameter variables
# Private control variables
self._cycle_time = 5
self._stopping = False
self._cache = {}
self._no_key = {}
self._hash = {}
self._pull_timeout = 75
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
self._long_poll_thread = None
self._change_listener = change_listener # "add" "delete" "update"
if _notification_map is None:
_notification_map = {"application": -1}
self._notification_map = _notification_map
self.last_release_key = None
# Private startup method
self._path_checker()
if start_hot_update:
self._start_hot_update()
# start the heartbeat thread
heartbeat = threading.Thread(target=self._heart_beat)
heartbeat.daemon = True
heartbeat.start()
def get_json_from_net(self, namespace="application"):
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
)
try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
if code == 200:
if not body:
logger.error(f"get_json_from_net load configs failed, body is {body}")
return None
data = json.loads(body)
data = data["configurations"]
return_data = {CONFIGURATIONS: data}
return return_data
else:
return None
except Exception:
logger.exception("an error occurred in get_json_from_net")
return None
def get_value(self, key, default_val=None, namespace="application"):
try:
# read memory configuration
namespace_cache = self._cache.get(namespace)
val = get_value_from_dict(namespace_cache, key)
if val is not None:
return val
no_key = no_key_cache_key(namespace, key)
if no_key in self._no_key:
return default_val
# read the network configuration
namespace_data = self.get_json_from_net(namespace)
val = get_value_from_dict(namespace_data, key)
if val is not None:
self._update_cache_and_file(namespace_data, namespace)
return val
# read the file configuration
namespace_cache = self._get_local_cache(namespace)
val = get_value_from_dict(namespace_cache, key)
if val is not None:
self._update_cache_and_file(namespace_cache, namespace)
return val
# If all of them are not obtained, the default value is returned
# and the local cache is set to None
self._set_local_cache_none(namespace, key)
return default_val
except Exception:
logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace)
return default_val
# Set the key of a namespace to none, and do not set default val
# to ensure the real-time correctness of the function call.
# If the user does not have the same default val twice
# and the default val is used here, there may be a problem.
def _set_local_cache_none(self, namespace, key):
no_key = no_key_cache_key(namespace, key)
self._no_key[no_key] = key
def _start_hot_update(self):
self._long_poll_thread = threading.Thread(target=self._listener)
# When the asynchronous thread is started, the daemon thread will automatically exit
# when the main thread is launched.
self._long_poll_thread.daemon = True
self._long_poll_thread.start()
def stop(self):
self._stopping = True
logger.info("Stopping listener...")
# Call the set callback function, and if it is abnormal, try it out
def _call_listener(self, namespace, old_kv, new_kv):
if self._change_listener is None:
return
if old_kv is None:
old_kv = {}
if new_kv is None:
new_kv = {}
try:
for key in old_kv:
new_value = new_kv.get(key)
old_value = old_kv.get(key)
if new_value is None:
# If newValue is empty, it means key, and the value is deleted.
self._change_listener("delete", namespace, key, old_value)
continue
if new_value != old_value:
self._change_listener("update", namespace, key, new_value)
continue
for key in new_kv:
new_value = new_kv.get(key)
old_value = old_kv.get(key)
if old_value is None:
self._change_listener("add", namespace, key, new_value)
except BaseException as e:
logger.warning(str(e))
def _path_checker(self):
if not os.path.isdir(self._cache_file_path):
makedirs_wrapper(self._cache_file_path)
# update the local cache and file cache
def _update_cache_and_file(self, namespace_data, namespace="application"):
# update the local cache
self._cache[namespace] = namespace_data
# update the file cache
new_string = json.dumps(namespace_data)
new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest()
if self._hash.get(namespace) == new_hash:
pass
else:
file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt"
file_path.write_text(new_string)
self._hash[namespace] = new_hash
# get the configuration from the local file
def _get_local_cache(self, namespace="application"):
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
if os.path.isfile(cache_file_path):
with open(cache_file_path) as f:
result = json.loads(f.readline())
return result
return {}
def _long_poll(self):
notifications = []
for key in self._cache:
namespace_data = self._cache[key]
notification_id = -1
if NOTIFICATION_ID in namespace_data:
notification_id = self._cache[key][NOTIFICATION_ID]
notifications.append({NAMESPACE_NAME: key, NOTIFICATION_ID: notification_id})
try:
# if the length is 0 it is returned directly
if len(notifications) == 0:
return
url = "{}/notifications/v2".format(self.config_url)
params = {
"appId": self.app_id,
"cluster": self.cluster,
"notifications": json.dumps(notifications, ensure_ascii=False),
}
param_str = url_encode_wrapper(params)
url = url + "?" + param_str
code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url))
http_code = code
if http_code == 304:
logger.debug("No change, loop...")
return
if http_code == 200:
if not body:
logger.error(f"_long_poll load configs failed,body is {body}")
return
data = json.loads(body)
for entry in data:
namespace = entry[NAMESPACE_NAME]
n_id = entry[NOTIFICATION_ID]
logger.info("%s has changes: notificationId=%d", namespace, n_id)
self._get_net_and_set_local(namespace, n_id, call_change=True)
return
else:
logger.warning("Sleep...")
except Exception as e:
logger.warning(str(e))
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
namespace_data = self.get_json_from_net(namespace)
if not namespace_data:
return
namespace_data[NOTIFICATION_ID] = n_id
old_namespace = self._cache.get(namespace)
self._update_cache_and_file(namespace_data, namespace)
if self._change_listener is not None and call_change and old_namespace:
old_kv = old_namespace.get(CONFIGURATIONS)
new_kv = namespace_data.get(CONFIGURATIONS)
self._call_listener(namespace, old_kv, new_kv)
def _listener(self):
logger.info("start long_poll")
while not self._stopping:
self._long_poll()
time.sleep(self._cycle_time)
logger.info("stopped, long_poll")
# add the need for endorsement to the header
def _sign_headers(self, url):
headers = {}
if self.secret == "":
return headers
uri = url[len(self.config_url) : len(url)]
time_unix_now = str(int(round(time.time() * 1000)))
headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret)
headers["Timestamp"] = time_unix_now
return headers
def _heart_beat(self):
while not self._stopping:
for namespace in self._notification_map:
self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10分钟
def _do_heart_beat(self, namespace):
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
if code == 200:
if not body:
logger.error(f"_do_heart_beat load configs failed,body is {body}")
return None
data = json.loads(body)
if self.last_release_key == data["releaseKey"]:
return None
self.last_release_key = data["releaseKey"]
data = data["configurations"]
self._update_cache_and_file(data, namespace)
else:
return None
except Exception:
logger.exception("an error occurred in _do_heart_beat")
return None
def get_all_dicts(self, namespace):
namespace_data = self._cache.get(namespace)
if namespace_data is None:
net_namespace_data = self.get_json_from_net(namespace)
if not net_namespace_data:
return namespace_data
namespace_data = net_namespace_data.get(CONFIGURATIONS)
if namespace_data:
self._update_cache_and_file(namespace_data, namespace)
return namespace_data

View File

@@ -1,41 +0,0 @@
import logging
import os
import ssl
import urllib.request
from urllib import parse
from urllib.error import HTTPError
# Create an SSL context that allows for a lower level of security
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("HIGH:!DH:!aNULL")
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
# Create an opener object and pass in a custom SSL context
opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context))
urllib.request.install_opener(opener)
logger = logging.getLogger(__name__)
def http_request(url, timeout, headers={}):
try:
request = urllib.request.Request(url, headers=headers)
res = urllib.request.urlopen(request, timeout=timeout)
body = res.read().decode("utf-8")
return res.code, body
except HTTPError as e:
if e.code == 304:
logger.warning("http_request error,code is 304, maybe you should check secret")
return 304, None
logger.warning("http_request error,code is %d, msg is %s", e.code, e.msg)
raise e
def url_encode(params):
return parse.urlencode(params)
def makedirs_wrapper(path):
os.makedirs(path, exist_ok=True)

View File

@@ -1,51 +0,0 @@
import hashlib
import socket
from .python_3x import url_encode
# define constants
CONFIGURATIONS = "configurations"
NOTIFICATION_ID = "notificationId"
NAMESPACE_NAME = "namespaceName"
# add timestamps uris and keys
def signature(timestamp, uri, secret):
import base64
import hmac
string_to_sign = "" + timestamp + "\n" + uri
hmac_code = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha1).digest()
return base64.b64encode(hmac_code).decode()
def url_encode_wrapper(params):
return url_encode(params)
def no_key_cache_key(namespace, key):
return "{}{}{}".format(namespace, len(namespace), key)
# Returns whether the obtained value is obtained, and None if it does not
def get_value_from_dict(namespace_cache, key):
if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None:
return None
if key in kv_data:
return kv_data[key]
return None
def init_ip():
ip = ""
s = None
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 53))
ip = s.getsockname()[0]
finally:
if s:
s.close()
return ip

View File

@@ -1,15 +0,0 @@
from collections.abc import Mapping
from typing import Any
from pydantic.fields import FieldInfo
class RemoteSettingsSource:
def __init__(self, configs: Mapping[str, Any]):
pass
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
return value

View File

@@ -1,5 +0,0 @@
from enum import StrEnum
class RemoteSettingsSourceName(StrEnum):
APOLLO = "apollo"

View File

@@ -14,11 +14,11 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else:
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

View File

@@ -31,7 +31,7 @@ def admin_required(view):
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if auth_token != dify_config.ADMIN_API_KEY:
if dify_config.ADMIN_API_KEY != auth_token:
raise Unauthorized("API key is invalid.")
return view(*args, **kwargs)

View File

@@ -65,7 +65,7 @@ class ModelConfigResource(Resource):
provider_type=agent_tool_entity.provider_type,
identity_id=f"AGENT.{app_model.id}",
)
except Exception:
except Exception as e:
continue
# get decrypted parameters
@@ -97,7 +97,7 @@ class ModelConfigResource(Resource):
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
except Exception:
except Exception as e:
continue
manager = ToolParameterConfigurationManager(

View File

@@ -1,5 +1,4 @@
from flask_restful import Resource, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import api
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
@@ -27,7 +26,7 @@ class TraceAppConfigApi(Resource):
return {"has_not_configured": True}
return trace_config
except Exception as e:
raise BadRequest(str(e))
raise e
@setup_required
@login_required
@@ -49,7 +48,7 @@ class TraceAppConfigApi(Resource):
raise TracingConfigCheckError()
return result
except Exception as e:
raise BadRequest(str(e))
raise e
@setup_required
@login_required
@@ -69,7 +68,7 @@ class TraceAppConfigApi(Resource):
raise TracingConfigNotExist()
return {"result": "success"}
except Exception as e:
raise BadRequest(str(e))
raise e
@setup_required
@login_required
@@ -86,7 +85,7 @@ class TraceAppConfigApi(Resource):
raise TracingConfigNotExist()
return {"result": "success"}
except Exception as e:
raise BadRequest(str(e))
raise e
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config")

View File

@@ -948,7 +948,7 @@ class DocumentRetryApi(DocumentResource):
if document.indexing_status == "completed":
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception:
except Exception as e:
logging.exception(f"Failed to retry document, document id: {document_id}")
continue
# retry document

View File

@@ -1,6 +1,5 @@
from datetime import UTC, datetime
from flask import request
from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_
@@ -21,17 +20,8 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@marshal_with(installed_app_list_fields)
def get(self):
app_id = request.args.get("app_id", default=None, type=str)
current_tenant_id = current_user.current_tenant_id
if app_id:
installed_apps = (
db.session.query(InstalledApp)
.filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all()
)
else:
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_apps = [

View File

@@ -4,7 +4,6 @@ from flask_restful import Resource, fields, marshal_with, reqparse
from constants.languages import languages
from controllers.console import api
from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField
from libs.login import login_required
from services.recommended_app_service import RecommendedAppService
@@ -13,8 +12,6 @@ app_fields = {
"name": fields.String,
"mode": fields.String,
"icon": fields.String,
"icon_type": fields.String,
"icon_url": AppIconUrlField,
"icon_background": fields.String,
}

View File

@@ -1,7 +1,6 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden
import services
from configs import dify_config
@@ -59,9 +58,6 @@ class FileApi(Resource):
if not file.filename:
raise FilenameNotExistsError
if source == "datasets" and not current_user.is_dataset_editor:
raise Forbidden()
if source not in ("datasets", None):
source = None

View File

@@ -368,7 +368,6 @@ class ToolWorkflowProviderCreateApi(Resource):
description=args["description"],
parameters=args["parameters"],
privacy_policy=args["privacy_policy"],
labels=args["labels"],
)

View File

@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload
from typing import Any, Optional, Union
from flask import Flask, current_app
from pydantic import ValidationError
@@ -36,29 +36,6 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
@@ -67,17 +44,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
) -> Mapping[str, Any] | Generator[str, None, None]:
"""
Generate App response.

View File

@@ -19,10 +19,8 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -33,7 +31,6 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -130,6 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
self.total_tokens: int = 0
def process(self):
"""
@@ -320,7 +318,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
@@ -329,22 +327,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
@@ -379,6 +361,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not workflow_run:
raise Exception("Workflow run not initialized.")
# FIXME for issue #11221 quick fix maybe have a better solution
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
@@ -392,7 +376,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=self._conversation.id,
@@ -403,29 +387,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
@@ -443,7 +404,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
yield self._workflow_finish_to_stream_response(

View File

@@ -2,7 +2,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from typing import Any, Union
from flask import Flask, current_app
from pydantic import ValidationError
@@ -28,39 +28,6 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self,
*,
@@ -69,7 +36,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
) -> Mapping[str, Any] | Generator[str, None, None]:
"""
Generate App response.

View File

@@ -82,7 +82,7 @@ class AppGenerateResponseConverter(ABC):
for resource in metadata["retriever_resources"]:
updated_resources.append(
{
"segment_id": resource.get("segment_id", ""),
"segment_id": resource["segment_id"],
"position": resource["position"],
"document_name": resource["document_name"],
"score": resource["score"],

View File

@@ -1,7 +1,7 @@
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
@@ -34,9 +34,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Any,
invoke_from: InvokeFrom,
streaming: Literal[True],
stream: Literal[True] = True,
) -> Generator[str, None, None]: ...
@overload
@@ -44,29 +44,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Any,
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
stream: Literal[False] = False,
) -> dict: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: Any,
invoke_from: InvokeFrom,
streaming: bool = True,
):
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.

View File

@@ -1,7 +1,7 @@
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
@@ -34,9 +34,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: dict,
invoke_from: InvokeFrom,
streaming: Literal[True],
stream: Literal[True] = True,
) -> Generator[str, None, None]: ...
@overload
@@ -44,29 +44,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
args: dict,
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
stream: Literal[False] = False,
) -> dict: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.

View File

@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Optional, Union, overload
from typing import Any, Optional, Union
from flask import Flask, current_app
from pydantic import ValidationError
@@ -30,35 +30,6 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
@@ -70,20 +41,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
) -> Mapping[str, Any] | Generator[str, None, None]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files

View File

@@ -6,7 +6,6 @@ from core.app.entities.queue_entities import (
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
@@ -35,8 +34,7 @@ class WorkflowAppQueueManager(AppQueueManager):
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
| QueueWorkflowFailedEvent,
):
self.stop_listen()

View File

@@ -15,10 +15,8 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -28,7 +26,6 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -109,6 +106,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self.total_tokens: int = 0
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -260,44 +258,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
node_start_response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
node_success_response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_success_response:
yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
node_failed_response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_failed_response:
yield node_failed_response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -305,7 +266,28 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
@@ -338,6 +320,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not workflow_run:
raise Exception("Workflow run not initialized.")
# FIXME for issue #11221 quick fix maybe have a better solution
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
@@ -351,7 +335,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
@@ -361,30 +345,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
@@ -394,6 +354,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@@ -405,7 +366,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log

View File

@@ -8,10 +8,8 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -20,7 +18,6 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -28,7 +25,6 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
@@ -36,10 +32,8 @@ from core.workflow.graph_engine.entities.event import (
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@@ -182,12 +176,8 @@ class WorkflowBasedAppRunner(AppRunner):
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self._publish_event(
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
@@ -263,36 +253,6 @@ class WorkflowBasedAppRunner(AppRunner):
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
@@ -422,36 +382,6 @@ class WorkflowBasedAppRunner(AppRunner):
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
elif isinstance(event, NodeRunRetryEvent):
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.error,
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
start_index=event.start_index,
)
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import NodeRunMetadataKey
@@ -25,14 +25,12 @@ class QueueEvent(StrEnum):
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
NODE_STARTED = "node_started"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
NODE_EXCEPTION = "node_exception"
RETRIEVER_RESOURCES = "retriever_resources"
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
@@ -43,7 +41,6 @@ class QueueEvent(StrEnum):
ERROR = "error"
PING = "ping"
STOP = "stop"
RETRY = "retry"
class AppQueueEvent(BaseModel):
@@ -116,6 +113,18 @@ class QueueIterationNextEvent(AppQueueEvent):
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None
@field_validator("output", mode="before")
@classmethod
def set_output(cls, v):
"""
Set output
"""
if v is None:
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError("output must be a valid type")
class QueueIterationCompletedEvent(AppQueueEvent):
"""
@@ -240,17 +249,6 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
exceptions_count: int
class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
"""
QueueWorkflowFailedEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: Optional[dict[str, Any]] = None
class QueueNodeStartedEvent(AppQueueEvent):
@@ -314,37 +312,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None
class QueueNodeRetryEvent(AppQueueEvent):
"""QueueNodeRetryEvent entity"""
event: QueueEvent = QueueEvent.RETRY
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
retry_index: int # retry index
start_index: int # start index
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
@@ -376,37 +343,6 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
error: str
class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
"""
event: QueueEvent = QueueEvent.NODE_EXCEPTION
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity

View File

@@ -52,7 +52,6 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
@@ -214,7 +213,6 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
@@ -343,75 +341,6 @@ class NodeFinishStreamResponse(StreamResponse):
}
class NodeRetryStreamResponse(StreamResponse):
"""
NodeFinishStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
retry_index: int = 0
event: StreamEvent = StreamEvent.NODE_RETRY
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
return {
"event": self.event.value,
"task_id": self.task_id,
"workflow_run_id": self.workflow_run_id,
"data": {
"id": self.data.id,
"node_id": self.data.node_id,
"node_type": self.data.node_type,
"title": self.data.title,
"index": self.data.index,
"predecessor_node_id": self.data.predecessor_node_id,
"inputs": None,
"process_data": None,
"outputs": None,
"status": self.data.status,
"error": None,
"elapsed_time": self.data.elapsed_time,
"execution_metadata": None,
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"retry_index": self.data.retry_index,
},
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity

View File

@@ -110,7 +110,7 @@ class RateLimitGenerator:
raise StopIteration
try:
return next(self.generator)
except Exception:
except StopIteration:
self.close()
raise

View File

@@ -12,10 +12,8 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -27,7 +25,6 @@ from core.app.entities.task_entities import (
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
@@ -167,55 +164,6 @@ class WorkflowCycleManage:
return workflow_run
def _handle_workflow_run_partial_success(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
db.session.refresh(workflow_run)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
db.session.close()
return workflow_run
def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
@@ -226,7 +174,6 @@ class WorkflowCycleManage:
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
) -> WorkflowRun:
"""
Workflow run failed
@@ -246,7 +193,7 @@ class WorkflowCycleManage:
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
running_workflow_node_executions = (
@@ -273,9 +220,9 @@ class WorkflowCycleManage:
db.session.close()
# with Session(db.engine, expire_on_commit=False) as session:
# session.add(workflow_run)
# session.refresh(workflow_run)
with Session(db.engine, expire_on_commit=False) as session:
session.add(workflow_run)
session.refresh(workflow_run)
if trace_manager:
trace_manager.add_trace_task(
@@ -371,7 +318,7 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
@@ -390,11 +337,7 @@ class WorkflowCycleManage:
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
),
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
@@ -408,11 +351,8 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
@@ -425,52 +365,6 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_retried(
self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
:return:
"""
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = created_at
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.index = event.start_index
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
return workflow_node_execution
#################################################
# to stream responses #
#################################################
@@ -539,7 +433,6 @@ class WorkflowCycleManage:
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
exceptions_count=workflow_run.exceptions_count,
),
)
@@ -590,10 +483,7 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeExceptionEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
@@ -635,51 +525,6 @@ class WorkflowCycleManage:
),
)
def _workflow_node_retry_to_stream_response(
self,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
),
)
def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:

View File

@@ -42,31 +42,39 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")
match f.type:
case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_class_map = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}
try:
return prompt_class_map[f.type](**params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")
return ImagePromptMessageContent(data=data, detail=image_detail_config)
case FileType.AUDIO:
encoded_string = _get_encoded_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
case _:
raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /):
@@ -120,6 +128,11 @@ def _get_encoded_string(f: File, /):
return encoded_string
def _to_base64_data_string(f: File, /):
encoded_string = _get_encoded_string(f)
return f"data:{f.mime_type};base64,{encoded_string}"
def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None:
@@ -128,7 +141,7 @@ def _to_url(f: File, /):
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
if f.related_id is None:
raise ValueError("Missing file related_id")
return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
return helpers.get_signed_file_url(upload_file_id=f.related_id)
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
# add sign url
if f.related_id is None or f.extension is None:

View File

@@ -24,12 +24,6 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
class MaxRetriesExceededError(Exception):
"""Raised when the maximum number of retries is exceeded."""
pass
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@@ -45,6 +39,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)
retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
@@ -69,7 +64,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

View File

@@ -1,9 +1,9 @@
from abc import ABC
from collections.abc import Sequence
from enum import Enum, StrEnum
from typing import Optional
from typing import Literal, Optional
from pydantic import BaseModel, Field, computed_field, field_validator
from pydantic import BaseModel, Field, field_validator
class PromptMessageRole(Enum):
@@ -67,6 +67,7 @@ class PromptMessageContent(BaseModel):
"""
type: PromptMessageContentType
data: str
class TextPromptMessageContent(PromptMessageContent):
@@ -75,35 +76,21 @@ class TextPromptMessageContent(PromptMessageContent):
"""
type: PromptMessageContentType = PromptMessageContentType.TEXT
data: str
class MultiModalPromptMessageContent(PromptMessageContent):
"""
Model class for multi-modal prompt message content.
"""
type: PromptMessageContentType
format: str = Field(..., description="the format of multi-modal file")
base64_data: str = Field("", description="the base64 data of multi-modal file")
url: str = Field("", description="the url of multi-modal file")
mime_type: str = Field(..., description="the mime type of multi-modal file")
@computed_field(return_type=str)
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
class VideoPromptMessageContent(MultiModalPromptMessageContent):
class VideoPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")
class AudioPromptMessageContent(MultiModalPromptMessageContent):
class AudioPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")
format: str = Field(..., description="Audio format")
class ImagePromptMessageContent(MultiModalPromptMessageContent):
class ImagePromptMessageContent(PromptMessageContent):
"""
Model class for image prompt message content.
"""
@@ -116,8 +103,11 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
class DocumentPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
mime_type: str
data: str
class PromptMessage(ABC, BaseModel):

View File

@@ -1,4 +1,5 @@
import base64
import io
import json
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
@@ -17,6 +18,7 @@ from anthropic.types import (
)
from anthropic.types.beta.tools import ToolsBetaMessage
from httpx import Timeout
from PIL import Image
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities import (
@@ -496,19 +498,22 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.base64_data:
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.url).content
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(
f"Failed to fetch image data from url {message_content.data}, {ex}"
)
else:
base64_data = message_content.base64_data
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
mime_type = message_content.mime_type
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError(
f"Unsupported image type {mime_type}, "
@@ -529,7 +534,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
sub_message_dict = {
"type": "document",
"source": {
"type": "base64",
"type": message_content.encode_format,
"media_type": message_content.mime_type,
"data": message_content.data,
},

View File

@@ -819,82 +819,6 @@ LLM_BASE_MODELS = [
),
),
),
AzureBaseModel(
base_model_name="gpt-4o-2024-11-20",
entity=AIModelEntity(
model="fake-deployment-name",
label=I18nObject(
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.VISION,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=16384),
ParameterRule(
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
max=1,
),
ParameterRule(
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
type="string",
help=I18nObject(
zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
),
required=False,
options=["text", "json_object", "json_schema"],
),
ParameterRule(
name="json_schema",
label=I18nObject(en_US="JSON Schema"),
type="text",
help=I18nObject(
zh_Hans="设置返回的json schemallm将按照它返回",
en_US="Set a response json schema will ensure LLM to adhere it.",
),
required=False,
),
],
pricing=PriceConfig(
input=5.00,
output=15.00,
unit=0.000001,
currency="USD",
),
),
),
AzureBaseModel(
base_model_name="gpt-4-turbo",
entity=AIModelEntity(

View File

@@ -86,9 +86,6 @@ model_credential_schema:
- label:
en_US: '2024-06-01'
value: '2024-06-01'
- label:
en_US: '2024-10-21'
value: '2024-10-21'
placeholder:
zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here
@@ -171,12 +168,6 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-11-20
value: gpt-4o-2024-11-20
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo
value: gpt-4-turbo

View File

@@ -92,10 +92,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
embeddings[i] = (average / np.linalg.norm(average)).tolist()
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

View File

@@ -10,7 +10,6 @@ from core.model_runtime.entities.llm_entities import (
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
@@ -106,11 +105,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_dict = {"role": "user", "content": message_content.data}
elif message_content.type == PromptMessageContentType.IMAGE:
raise ValueError("Content object type not support image_url")
raise ValueError("User message content must be str")
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}

View File

@@ -1,29 +0,0 @@
from collections.abc import Mapping
import boto3
from botocore.config import Config
from core.model_runtime.errors.invoke import InvokeBadRequestError
def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
region_name = credentials.get("aws_region")
if not region_name:
raise InvokeBadRequestError("aws_region is required")
client_config = Config(region_name=region_name)
aws_access_key_id = credentials.get("aws_access_key_id")
aws_secret_access_key = credentials.get("aws_secret_access_key")
if aws_access_key_id and aws_secret_access_key:
# use aksk to call bedrock
client = boto3.client(
service_name=service_name,
config=client_config,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
else:
# use iam without aksk to call
client = boto3.client(service_name=service_name, config=client_config)
return client

View File

@@ -40,7 +40,6 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
logger = logging.getLogger(__name__)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
@@ -174,7 +173,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
bedrock_client = get_bedrock_client("bedrock-runtime", credentials)
bedrock_client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
region_name=credentials["aws_region"],
)
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

View File

@@ -1,5 +1,8 @@
from typing import Optional
import boto3
from botocore.config import Config
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
@@ -11,7 +14,6 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
class BedrockRerankModel(RerankModel):
@@ -46,7 +48,13 @@ class BedrockRerankModel(RerankModel):
return RerankResult(model=model, docs=docs)
# initialize client
bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
client_config = Config(region_name=credentials["aws_region"])
bedrock_runtime = boto3.client(
service_name="bedrock-agent-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id", ""),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
text_sources = []
for text in docs:
@@ -62,10 +70,7 @@ class BedrockRerankModel(RerankModel):
}
)
modelId = model
region = credentials.get("aws_region")
# region is a required field
if not region:
raise InvokeBadRequestError("aws_region is required in credentials")
region = credentials["aws_region"]
model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
rerankingConfiguration = {
"type": "BEDROCK_RERANKING_MODEL",

View File

@@ -3,6 +3,8 @@ import logging
import time
from typing import Optional
import boto3
from botocore.config import Config
from botocore.exceptions import (
ClientError,
EndpointConnectionError,
@@ -23,7 +25,6 @@ from core.model_runtime.errors.invoke import (
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
logger = logging.getLogger(__name__)
@@ -47,7 +48,14 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
:param input_type: input type
:return: embeddings result
"""
bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)
client_config = Config(region_name=credentials["aws_region"])
bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
embeddings = []
token_usage = 0

View File

@@ -88,10 +88,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
embeddings[i] = (average / np.linalg.norm(average)).tolist()
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

View File

@@ -24,9 +24,6 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
# {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:

View File

@@ -1,93 +0,0 @@
model: InternVL2-8B
label:
en_US: InternVL2-8B
model_type: llm
features:
- vision
- agent-thought
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_k
use_template: top_k
label:
en_US: "Top K"
zh_Hans: "Top K"
type: int
default: 50
min: 0
max: 100
required: true
help:
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@@ -1,93 +0,0 @@
model: InternVL2.5-26B
label:
en_US: InternVL2.5-26B
model_type: llm
features:
- vision
- agent-thought
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: max_tokens
use_template: max_tokens
label:
en_US: "Max Tokens"
zh_Hans: "最大Token数"
type: int
default: 512
min: 1
required: true
help:
en_US: "The maximum number of tokens that can be generated by the model varies depending on the model."
zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。"
- name: temperature
use_template: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
use_template: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_k
use_template: top_k
label:
en_US: "Top K"
zh_Hans: "Top K"
type: int
default: 50
min: 0
max: 100
required: true
help:
en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be."
zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: -1.0
max: 1.0
precision: 1
required: false
help:
en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation."
zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@@ -6,5 +6,3 @@
- deepseek-coder-33B-instruct-chat
- deepseek-coder-33B-instruct-completions
- codegeex4-all-9b
- InternVL2.5-26B
- InternVL2-8B

View File

@@ -29,26 +29,18 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials, model, model_parameters)
return super()._invoke(
GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model),
credentials,
prompt_messages,
model_parameters,
tools,
stop,
stream,
user,
)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials, model, None)
super().validate_credentials(GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model), credentials)
self._add_custom_parameters(credentials, None)
super().validate_credentials(model, credentials)
def _add_custom_parameters(self, credentials: dict, model: Optional[str], model_parameters: dict) -> None:
def _add_custom_parameters(self, credentials: dict, model: Optional[str]) -> None:
if model is None:
model = "Qwen2-72B-Instruct"
credentials["endpoint_url"] = "https://ai.gitee.com/v1"
model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model)
credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/"
if model.endswith("completions"):
credentials["mode"] = LLMMode.COMPLETION.value
else:

View File

@@ -1,5 +1,3 @@
- gemini-2.0-flash-exp
- gemini-2.0-flash-thinking-exp-1219
- gemini-1.5-pro
- gemini-1.5-pro-latest
- gemini-1.5-pro-001
@@ -13,8 +11,6 @@
- gemini-1.5-flash-exp-0827
- gemini-1.5-flash-8b-exp-0827
- gemini-1.5-flash-8b-exp-0924
- gemini-exp-1206
- gemini-exp-1121
- gemini-exp-1114
- gemini-pro
- gemini-pro-vision

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152

View File

@@ -1,41 +0,0 @@
model: gemini-2.0-flash-exp
label:
en_US: Gemini 2.0 Flash Exp
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -1,39 +0,0 @@
model: gemini-2.0-flash-thinking-exp-1219
label:
en_US: Gemini 2.0 Flash Thinking Exp 1219
model_type: llm
features:
- agent-thought
- vision
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -8,8 +8,6 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767

View File

@@ -7,9 +7,6 @@ features:
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767

View File

@@ -1,41 +0,0 @@
model: gemini-exp-1206
label:
en_US: Gemini exp 1206
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -7,9 +7,6 @@ features:
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767

View File

@@ -1,23 +1,24 @@
import base64
import io
import json
import os
import tempfile
import time
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast
import google.ai.generativelanguage as glm
import google.generativeai as genai
import requests
from google.api_core import exceptions
from google.generativeai.types import ContentType, File, GenerateContentResponse
from google.generativeai.client import _ClientManager
from google.generativeai.types import ContentType, GenerateContentResponse
from google.generativeai.types.content_types import to_part
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
@@ -34,7 +35,21 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_redis import redis_client
GOOGLE_AVAILABLE_MIMETYPE = [
"application/pdf",
"application/x-javascript",
"text/javascript",
"application/x-python",
"text/x-python",
"text/plain",
"text/html",
"text/css",
"text/md",
"text/csv",
"text/xml",
"text/rtf",
]
class GoogleLargeLanguageModel(LargeLanguageModel):
@@ -186,17 +201,29 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
if stop:
config_kwargs["stop_sequences"] = stop
genai.configure(api_key=credentials["google_api_key"])
google_model = genai.GenerativeModel(model_name=model)
history = []
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
# hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-pro-vision":
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
# Create a new ClientManager with tenant's API key
new_client_manager = _ClientManager()
new_client_manager.configure(api_key=credentials["google_api_key"])
new_custom_client = new_client_manager.make_client("generative")
google_model._client = new_custom_client
response = google_model.generate_content(
contents=history,
@@ -290,12 +317,8 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
)
else:
# calculate num tokens
if hasattr(response, "usage_metadata") and response.usage_metadata:
prompt_tokens = response.usage_metadata.prompt_token_count
completion_tokens = response.usage_metadata.candidates_token_count
else:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -323,7 +346,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content = message.content
if isinstance(content, list):
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
@@ -336,40 +359,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
return message_text
def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
key = f"{message_content.type.value}:{hash(message_content.data)}"
if redis_client.exists(key):
try:
return genai.get_file(redis_client.get(key).decode())
except:
pass
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
if message_content.base64_data:
file_content = base64.b64decode(message_content.base64_data)
temp_file.write(file_content)
else:
try:
response = requests.get(message_content.url)
response.raise_for_status()
temp_file.write(response.content)
except Exception as ex:
raise ValueError(f"Failed to fetch data from url {message_content.url}, {ex}")
temp_file.flush()
file = genai.upload_file(path=temp_file.name, mime_type=message_content.mime_type)
while file.state.name == "PROCESSING":
time.sleep(5)
file = genai.get_file(file.name)
# google will delete your upload files in 2 days.
redis_client.setex(key, 47 * 60 * 60, file.name)
try:
os.unlink(temp_file.name)
except PermissionError:
# windows may raise permission error
pass
return file
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
"""
Format a single message into glm.Content for Google API
@@ -385,8 +374,28 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
glm_content["parts"].append(to_part(c.data))
else:
glm_content["parts"].append(self._upload_file_content_to_google(c))
elif c.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
else:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
elif c.type == PromptMessageContentType.DOCUMENT:
message_content = cast(DocumentPromptMessageContent, c)
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
glm_content["parts"].append(blob)
return glm_content
elif isinstance(message, AssistantPromptMessage):

View File

@@ -1,5 +1,4 @@
- llama-3.1-405b-reasoning
- llama-3.3-70b-versatile
- llama-3.1-70b-versatile
- llama-3.1-8b-instant
- llama3-70b-8192

View File

@@ -1,25 +0,0 @@
model: gemma-7b-it
label:
zh_Hans: Gemma 7B Instruction Tuned
en_US: Gemma 7B Instruction Tuned
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.05'
output: '0.1'
unit: '0.000001'
currency: USD

View File

@@ -1,25 +0,0 @@
model: gemma2-9b-it
label:
zh_Hans: Gemma 2 9B Instruction Tuned
en_US: Gemma 2 9B Instruction Tuned
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.05'
output: '0.1'
unit: '0.000001'
currency: USD

View File

@@ -1,8 +1,7 @@
model: llama-3.1-70b-versatile
deprecated: true
label:
zh_Hans: Llama-3.1-70b-versatile (DEPRECATED)
en_US: Llama-3.1-70b-versatile (DEPRECATED)
zh_Hans: Llama-3.1-70b-versatile
en_US: Llama-3.1-70b-versatile
model_type: llm
features:
- agent-thought

View File

@@ -1,5 +1,4 @@
model: llama-3.2-11b-text-preview
deprecated: true
label:
zh_Hans: Llama 3.2 11B Text (Preview)
en_US: Llama 3.2 11B Text (Preview)

View File

@@ -1,5 +1,4 @@
model: llama-3.2-90b-text-preview
depraceted: true
label:
zh_Hans: Llama 3.2 90B Text (Preview)
en_US: Llama 3.2 90B Text (Preview)

View File

@@ -1,25 +0,0 @@
model: llama-3.3-70b-specdec
label:
zh_Hans: Llama 3.3 70B Specdec
en_US: Llama 3.3 70B Specdec
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32768
pricing:
input: "0.05"
output: "0.1"
unit: "0.000001"
currency: USD

View File

@@ -1,25 +0,0 @@
model: llama-3.3-70b-versatile
label:
zh_Hans: Llama 3.3 70B Versatile
en_US: Llama 3.3 70B Versatile
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32768
pricing:
input: "0.05"
output: "0.1"
unit: "0.000001"
currency: USD

Some files were not shown because too many files have changed in this diff Show More