mirror of
https://github.com/langgenius/dify.git
synced 2026-01-10 08:14:14 +00:00
Compare commits
3 Commits
0.14.0
...
feat/suppo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac0d99281e | ||
|
|
bbdadec1bc | ||
|
|
fa9709faa8 |
@@ -7,6 +7,5 @@ echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
|
||||
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
|
||||
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
|
||||
|
||||
source /home/vscode/.bashrc
|
||||
source /home/vscode/.bashrc
|
||||
13
.github/pull_request_template.md
vendored
13
.github/pull_request_template.md
vendored
@@ -8,9 +8,16 @@ Please include a summary of the change and which issue is fixed. Please also inc
|
||||
|
||||
# Screenshots
|
||||
|
||||
| Before | After |
|
||||
|--------|-------|
|
||||
| ... | ... |
|
||||
<table>
|
||||
<tr>
|
||||
<td>Before: </td>
|
||||
<td>After: </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>...</td>
|
||||
<td>...</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
# Checklist
|
||||
|
||||
|
||||
1
.github/workflows/style.yml
vendored
1
.github/workflows/style.yml
vendored
@@ -37,7 +37,6 @@ jobs:
|
||||
- name: Ruff check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: |
|
||||
poetry run -C api ruff --version
|
||||
poetry run -C api ruff check ./api
|
||||
poetry run -C api ruff format --check ./api
|
||||
|
||||
|
||||
3
.github/workflows/vdb-tests.yml
vendored
3
.github/workflows/vdb-tests.yml
vendored
@@ -51,7 +51,7 @@ jobs:
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
@@ -67,7 +67,6 @@ jobs:
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
tidb
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: poetry run -C api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
@@ -56,36 +56,20 @@ DB_DATABASE=dify
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
|
||||
STORAGE_TYPE=opendal
|
||||
|
||||
# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
|
||||
STORAGE_OPENDAL_SCHEME=fs
|
||||
# OpenDAL FS
|
||||
OPENDAL_FS_ROOT=storage
|
||||
# OpenDAL S3
|
||||
OPENDAL_S3_ROOT=/
|
||||
OPENDAL_S3_BUCKET=your-bucket-name
|
||||
OPENDAL_S3_ENDPOINT=https://s3.amazonaws.com
|
||||
OPENDAL_S3_ACCESS_KEY_ID=your-access-key
|
||||
OPENDAL_S3_SECRET_ACCESS_KEY=your-secret-key
|
||||
OPENDAL_S3_REGION=your-region
|
||||
OPENDAL_S3_SERVER_SIDE_ENCRYPTION=
|
||||
|
||||
# S3 Storage configuration
|
||||
# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
|
||||
STORAGE_TYPE=local
|
||||
STORAGE_LOCAL_PATH=storage
|
||||
S3_USE_AWS_MANAGED_IAM=false
|
||||
S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com
|
||||
S3_BUCKET_NAME=your-bucket-name
|
||||
S3_ACCESS_KEY=your-access-key
|
||||
S3_SECRET_KEY=your-secret-key
|
||||
S3_REGION=your-region
|
||||
|
||||
# Azure Blob Storage configuration
|
||||
AZURE_BLOB_ACCOUNT_NAME=your-account-name
|
||||
AZURE_BLOB_ACCOUNT_KEY=your-account-key
|
||||
AZURE_BLOB_CONTAINER_NAME=yout-container-name
|
||||
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
|
||||
|
||||
# Aliyun oss Storage configuration
|
||||
ALIYUN_OSS_BUCKET_NAME=your-bucket-name
|
||||
ALIYUN_OSS_ACCESS_KEY=your-access-key
|
||||
@@ -95,7 +79,6 @@ ALIYUN_OSS_AUTH_VERSION=v1
|
||||
ALIYUN_OSS_REGION=your-region
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
|
||||
# Google Storage configuration
|
||||
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
|
||||
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
|
||||
@@ -142,8 +125,8 @@ SUPABASE_URL=your-server-url
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Vector database configuration
|
||||
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# Weaviate configuration
|
||||
@@ -294,7 +277,6 @@ VIKINGDB_SOCKET_TIMEOUT=30
|
||||
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
|
||||
LINDORM_USERNAME=admin
|
||||
LINDORM_PASSWORD=admin
|
||||
USING_UGC_INDEX=False
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
@@ -399,8 +381,6 @@ LOG_FILE_BACKUP_COUNT=5
|
||||
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
|
||||
# Log Timezone
|
||||
LOG_TZ=UTC
|
||||
# Log format
|
||||
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
|
||||
|
||||
# Indexing configuration
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
|
||||
@@ -433,5 +413,3 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
|
||||
|
||||
CREATE_TIDB_SERVICE_JOB_ENABLED=false
|
||||
|
||||
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
|
||||
MAX_SUBMIT_COUNT=100
|
||||
|
||||
@@ -259,7 +259,7 @@ def migrate_knowledge_vector_database():
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
upper_collection_vector_types = {
|
||||
upper_colletion_vector_types = {
|
||||
VectorType.MILVUS,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.RELYT,
|
||||
@@ -267,7 +267,7 @@ def migrate_knowledge_vector_database():
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
}
|
||||
lower_collection_vector_types = {
|
||||
lower_colletion_vector_types = {
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.CHROMA,
|
||||
VectorType.MYSCALE,
|
||||
@@ -307,7 +307,7 @@ def migrate_knowledge_vector_database():
|
||||
continue
|
||||
collection_name = ""
|
||||
dataset_id = dataset.id
|
||||
if vector_type in upper_collection_vector_types:
|
||||
if vector_type in upper_colletion_vector_types:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
elif vector_type == VectorType.QDRANT:
|
||||
if dataset.collection_binding_id:
|
||||
@@ -323,7 +323,7 @@ def migrate_knowledge_vector_database():
|
||||
else:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
|
||||
elif vector_type in lower_collection_vector_types:
|
||||
elif vector_type in lower_colletion_vector_types:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
@@ -1,51 +1,11 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
|
||||
|
||||
from .deploy import DeploymentConfig
|
||||
from .enterprise import EnterpriseFeatureConfig
|
||||
from .extra import ExtraServiceConfig
|
||||
from .feature import FeatureConfig
|
||||
from .middleware import MiddlewareConfig
|
||||
from .packaging import PackagingInfo
|
||||
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
|
||||
from .remote_settings_sources.apollo import ApolloSettingsSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
|
||||
def __init__(self, settings_cls: type[BaseSettings]):
|
||||
super().__init__(settings_cls)
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
current_state = self.current_state
|
||||
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
|
||||
if not remote_source_name:
|
||||
return {}
|
||||
|
||||
remote_source: RemoteSettingsSource | None = None
|
||||
match remote_source_name:
|
||||
case RemoteSettingsSourceName.APOLLO:
|
||||
remote_source = ApolloSettingsSource(current_state)
|
||||
case _:
|
||||
logger.warning(f"Unsupported remote source: {remote_source_name}")
|
||||
return {}
|
||||
|
||||
d: dict[str, Any] = {}
|
||||
|
||||
for field_name, field in self.settings_cls.model_fields.items():
|
||||
field_value, field_key, value_is_complex = remote_source.get_field_value(field, field_name)
|
||||
field_value = remote_source.prepare_field_value(field_name, field, field_value, value_is_complex)
|
||||
if field_value is not None:
|
||||
d[field_key] = field_value
|
||||
|
||||
return d
|
||||
from configs.deploy import DeploymentConfig
|
||||
from configs.enterprise import EnterpriseFeatureConfig
|
||||
from configs.extra import ExtraServiceConfig
|
||||
from configs.feature import FeatureConfig
|
||||
from configs.middleware import MiddlewareConfig
|
||||
from configs.packaging import PackagingInfo
|
||||
|
||||
|
||||
class DifyConfig(
|
||||
@@ -59,8 +19,6 @@ class DifyConfig(
|
||||
MiddlewareConfig,
|
||||
# Extra service configs
|
||||
ExtraServiceConfig,
|
||||
# Remote source configs
|
||||
RemoteSettingsSourceConfig,
|
||||
# Enterprise feature configs
|
||||
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
EnterpriseFeatureConfig,
|
||||
@@ -77,20 +35,3 @@ class DifyConfig(
|
||||
# please consider to arrange it in the proper config group of existed or added
|
||||
# for better readability and maintainability.
|
||||
# Thanks for your concentration and consideration.
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (
|
||||
init_settings,
|
||||
env_settings,
|
||||
RemoteSettingsSourceFactory(settings_cls),
|
||||
dotenv_settings,
|
||||
file_secret_settings,
|
||||
)
|
||||
|
||||
@@ -439,17 +439,6 @@ class WorkflowConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class WorkflowNodeExecutionConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for workflow node execution
|
||||
"""
|
||||
|
||||
MAX_SUBMIT_COUNT: PositiveInt = Field(
|
||||
description="Maximum number of submitted thread count in a ThreadPool for parallel node execution",
|
||||
default=100,
|
||||
)
|
||||
|
||||
|
||||
class AuthConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for authentication and OAuth
|
||||
@@ -786,7 +775,6 @@ class FeatureConfig(
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
LoginConfig,
|
||||
# hosted services config
|
||||
|
||||
@@ -1,69 +1,54 @@
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from .cache.redis_config import RedisConfig
|
||||
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
|
||||
from .storage.amazon_s3_storage_config import S3StorageConfig
|
||||
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
|
||||
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
|
||||
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
|
||||
from .storage.oci_storage_config import OCIStorageConfig
|
||||
from .storage.opendal_storage_config import OpenDALStorageConfig
|
||||
from .storage.supabase_storage_config import SupabaseStorageConfig
|
||||
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from .vdb.analyticdb_config import AnalyticdbConfig
|
||||
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||
from .vdb.chroma_config import ChromaConfig
|
||||
from .vdb.couchbase_config import CouchbaseConfig
|
||||
from .vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from .vdb.lindorm_config import LindormConfig
|
||||
from .vdb.milvus_config import MilvusConfig
|
||||
from .vdb.myscale_config import MyScaleConfig
|
||||
from .vdb.oceanbase_config import OceanBaseVectorConfig
|
||||
from .vdb.opensearch_config import OpenSearchConfig
|
||||
from .vdb.oracle_config import OracleConfig
|
||||
from .vdb.pgvector_config import PGVectorConfig
|
||||
from .vdb.pgvectors_config import PGVectoRSConfig
|
||||
from .vdb.qdrant_config import QdrantConfig
|
||||
from .vdb.relyt_config import RelytConfig
|
||||
from .vdb.tencent_vector_config import TencentVectorDBConfig
|
||||
from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
|
||||
from .vdb.tidb_vector_config import TiDBVectorConfig
|
||||
from .vdb.upstash_config import UpstashConfig
|
||||
from .vdb.vikingdb_config import VikingDBConfig
|
||||
from .vdb.weaviate_config import WeaviateConfig
|
||||
from configs.middleware.cache.redis_config import RedisConfig
|
||||
from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
|
||||
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
|
||||
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||
from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageConfig
|
||||
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
|
||||
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
|
||||
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
|
||||
from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
|
||||
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
|
||||
from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||
from configs.middleware.vdb.chroma_config import ChromaConfig
|
||||
from configs.middleware.vdb.couchbase_config import CouchbaseConfig
|
||||
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from configs.middleware.vdb.lindorm_config import LindormConfig
|
||||
from configs.middleware.vdb.milvus_config import MilvusConfig
|
||||
from configs.middleware.vdb.myscale_config import MyScaleConfig
|
||||
from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
|
||||
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
|
||||
from configs.middleware.vdb.oracle_config import OracleConfig
|
||||
from configs.middleware.vdb.pgvector_config import PGVectorConfig
|
||||
from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
|
||||
from configs.middleware.vdb.qdrant_config import QdrantConfig
|
||||
from configs.middleware.vdb.relyt_config import RelytConfig
|
||||
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
|
||||
from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
|
||||
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
|
||||
from configs.middleware.vdb.upstash_config import UpstashConfig
|
||||
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
|
||||
from configs.middleware.vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
|
||||
class StorageConfig(BaseSettings):
|
||||
STORAGE_TYPE: Literal[
|
||||
"opendal",
|
||||
"s3",
|
||||
"aliyun-oss",
|
||||
"azure-blob",
|
||||
"baidu-obs",
|
||||
"google-storage",
|
||||
"huawei-obs",
|
||||
"oci-storage",
|
||||
"tencent-cos",
|
||||
"volcengine-tos",
|
||||
"supabase",
|
||||
"local",
|
||||
] = Field(
|
||||
STORAGE_TYPE: str = Field(
|
||||
description="Type of storage to use."
|
||||
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
|
||||
"'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
|
||||
default="opendal",
|
||||
" Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', "
|
||||
"'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.",
|
||||
default="local",
|
||||
)
|
||||
|
||||
STORAGE_LOCAL_PATH: str = Field(
|
||||
description="Path for local storage when STORAGE_TYPE is set to 'local'.",
|
||||
default="storage",
|
||||
deprecated=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -88,7 +73,7 @@ class KeywordStoreConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
class DatabaseConfig:
|
||||
DB_HOST: str = Field(
|
||||
description="Hostname or IP address of the database server.",
|
||||
default="localhost",
|
||||
@@ -250,7 +235,6 @@ class MiddlewareConfig(
|
||||
GoogleCloudStorageConfig,
|
||||
HuaweiCloudOBSStorageConfig,
|
||||
OCIStorageConfig,
|
||||
OpenDALStorageConfig,
|
||||
S3StorageConfig,
|
||||
SupabaseStorageConfig,
|
||||
TencentCloudCOSStorageConfig,
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BaiduOBSStorageConfig(BaseSettings):
|
||||
class BaiduOBSStorageConfig(BaseModel):
|
||||
"""
|
||||
Configuration settings for Baidu Object Storage Service (OBS)
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class HuaweiCloudOBSStorageConfig(BaseSettings):
|
||||
class HuaweiCloudOBSStorageConfig(BaseModel):
|
||||
"""
|
||||
Configuration settings for Huawei Cloud Object Storage Service (OBS)
|
||||
"""
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class OpenDALScheme(StrEnum):
|
||||
FS = "fs"
|
||||
S3 = "s3"
|
||||
|
||||
|
||||
class OpenDALStorageConfig(BaseSettings):
|
||||
STORAGE_OPENDAL_SCHEME: str = Field(
|
||||
default=OpenDALScheme.FS.value,
|
||||
description="OpenDAL scheme.",
|
||||
)
|
||||
# FS
|
||||
OPENDAL_FS_ROOT: str = Field(
|
||||
default="storage",
|
||||
description="Root path for local storage.",
|
||||
)
|
||||
# S3
|
||||
OPENDAL_S3_ROOT: str = Field(
|
||||
default="/",
|
||||
description="Root path for S3 storage.",
|
||||
)
|
||||
OPENDAL_S3_BUCKET: str = Field(
|
||||
default="",
|
||||
description="S3 bucket name.",
|
||||
)
|
||||
OPENDAL_S3_ENDPOINT: str = Field(
|
||||
default="https://s3.amazonaws.com",
|
||||
description="S3 endpoint URL.",
|
||||
)
|
||||
OPENDAL_S3_ACCESS_KEY_ID: str = Field(
|
||||
default="",
|
||||
description="S3 access key ID.",
|
||||
)
|
||||
OPENDAL_S3_SECRET_ACCESS_KEY: str = Field(
|
||||
default="",
|
||||
description="S3 secret access key.",
|
||||
)
|
||||
OPENDAL_S3_REGION: str = Field(
|
||||
default="",
|
||||
description="S3 region.",
|
||||
)
|
||||
OPENDAL_S3_SERVER_SIDE_ENCRYPTION: Literal["aws:kms", ""] = Field(
|
||||
default="",
|
||||
description="S3 server-side encryption.",
|
||||
)
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SupabaseStorageConfig(BaseSettings):
|
||||
class SupabaseStorageConfig(BaseModel):
|
||||
"""
|
||||
Configuration settings for Supabase Object Storage Service
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VolcengineTOSStorageConfig(BaseSettings):
|
||||
class VolcengineTOSStorageConfig(BaseModel):
|
||||
"""
|
||||
Configuration settings for Volcengine Tinder Object Storage (TOS)
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
|
||||
|
||||
class AnalyticdbConfig(BaseSettings):
|
||||
class AnalyticdbConfig(BaseModel):
|
||||
"""
|
||||
Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
|
||||
Refer to the following documentation for details on obtaining credentials:
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CouchbaseConfig(BaseSettings):
|
||||
class CouchbaseConfig(BaseModel):
|
||||
"""
|
||||
Couchbase configs
|
||||
"""
|
||||
|
||||
@@ -21,14 +21,3 @@ class LindormConfig(BaseSettings):
|
||||
description="Lindorm password",
|
||||
default=None,
|
||||
)
|
||||
DEFAULT_INDEX_TYPE: Optional[str] = Field(
|
||||
description="Lindorm Vector Index Type, hnsw or flat is available in dify",
|
||||
default="hnsw",
|
||||
)
|
||||
DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
|
||||
description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
|
||||
)
|
||||
USING_UGC_INDEX: Optional[bool] = Field(
|
||||
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
|
||||
|
||||
class MyScaleConfig(BaseSettings):
|
||||
class MyScaleConfig(BaseModel):
|
||||
"""
|
||||
Configuration settings for MyScale vector database
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VikingDBConfig(BaseSettings):
|
||||
class VikingDBConfig(BaseModel):
|
||||
"""
|
||||
Configuration for connecting to Volcengine VikingDB.
|
||||
Refer to the following documentation for details on obtaining credentials:
|
||||
|
||||
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.14.0",
|
||||
default="0.13.0",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .apollo import ApolloSettingsSourceInfo
|
||||
from .base import RemoteSettingsSource
|
||||
from .enums import RemoteSettingsSourceName
|
||||
|
||||
|
||||
class RemoteSettingsSourceConfig(ApolloSettingsSourceInfo):
|
||||
REMOTE_SETTINGS_SOURCE_NAME: RemoteSettingsSourceName | str = Field(
|
||||
description="name of remote config source",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["RemoteSettingsSource", "RemoteSettingsSourceConfig", "RemoteSettingsSourceName"]
|
||||
@@ -1,55 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.remote_settings_sources.base import RemoteSettingsSource
|
||||
|
||||
from .client import ApolloClient
|
||||
|
||||
|
||||
class ApolloSettingsSourceInfo(BaseSettings):
|
||||
"""
|
||||
Packaging build information
|
||||
"""
|
||||
|
||||
APOLLO_APP_ID: Optional[str] = Field(
|
||||
description="apollo app_id",
|
||||
default=None,
|
||||
)
|
||||
|
||||
APOLLO_CLUSTER: Optional[str] = Field(
|
||||
description="apollo cluster",
|
||||
default=None,
|
||||
)
|
||||
|
||||
APOLLO_CONFIG_URL: Optional[str] = Field(
|
||||
description="apollo config url",
|
||||
default=None,
|
||||
)
|
||||
|
||||
APOLLO_NAMESPACE: Optional[str] = Field(
|
||||
description="apollo namespace",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class ApolloSettingsSource(RemoteSettingsSource):
|
||||
def __init__(self, configs: Mapping[str, Any]):
|
||||
self.client = ApolloClient(
|
||||
app_id=configs["APOLLO_APP_ID"],
|
||||
cluster=configs["APOLLO_CLUSTER"],
|
||||
config_url=configs["APOLLO_CONFIG_URL"],
|
||||
start_hot_update=False,
|
||||
_notification_map={configs["APOLLO_NAMESPACE"]: -1},
|
||||
)
|
||||
self.namespace = configs["APOLLO_NAMESPACE"]
|
||||
self.remote_configs = self.client.get_all_dicts(self.namespace)
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
if not isinstance(self.remote_configs, dict):
|
||||
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
||||
field_value = self.remote_configs.get(field_name)
|
||||
return field_value, field_name, False
|
||||
@@ -1,303 +0,0 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from .python_3x import http_request, makedirs_wrapper
|
||||
from .utils import (
|
||||
CONFIGURATIONS,
|
||||
NAMESPACE_NAME,
|
||||
NOTIFICATION_ID,
|
||||
get_value_from_dict,
|
||||
init_ip,
|
||||
no_key_cache_key,
|
||||
signature,
|
||||
url_encode_wrapper,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApolloClient:
|
||||
def __init__(
|
||||
self,
|
||||
config_url,
|
||||
app_id,
|
||||
cluster="default",
|
||||
secret="",
|
||||
start_hot_update=True,
|
||||
change_listener=None,
|
||||
_notification_map=None,
|
||||
):
|
||||
# Core routing parameters
|
||||
self.config_url = config_url
|
||||
self.cluster = cluster
|
||||
self.app_id = app_id
|
||||
|
||||
# Non-core parameters
|
||||
self.ip = init_ip()
|
||||
self.secret = secret
|
||||
|
||||
# Check the parameter variables
|
||||
|
||||
# Private control variables
|
||||
self._cycle_time = 5
|
||||
self._stopping = False
|
||||
self._cache = {}
|
||||
self._no_key = {}
|
||||
self._hash = {}
|
||||
self._pull_timeout = 75
|
||||
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
|
||||
self._long_poll_thread = None
|
||||
self._change_listener = change_listener # "add" "delete" "update"
|
||||
if _notification_map is None:
|
||||
_notification_map = {"application": -1}
|
||||
self._notification_map = _notification_map
|
||||
self.last_release_key = None
|
||||
# Private startup method
|
||||
self._path_checker()
|
||||
if start_hot_update:
|
||||
self._start_hot_update()
|
||||
|
||||
# start the heartbeat thread
|
||||
heartbeat = threading.Thread(target=self._heart_beat)
|
||||
heartbeat.daemon = True
|
||||
heartbeat.start()
|
||||
|
||||
def get_json_from_net(self, namespace="application"):
|
||||
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
|
||||
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
|
||||
)
|
||||
try:
|
||||
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
|
||||
if code == 200:
|
||||
if not body:
|
||||
logger.error(f"get_json_from_net load configs failed, body is {body}")
|
||||
return None
|
||||
data = json.loads(body)
|
||||
data = data["configurations"]
|
||||
return_data = {CONFIGURATIONS: data}
|
||||
return return_data
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("an error occurred in get_json_from_net")
|
||||
return None
|
||||
|
||||
def get_value(self, key, default_val=None, namespace="application"):
|
||||
try:
|
||||
# read memory configuration
|
||||
namespace_cache = self._cache.get(namespace)
|
||||
val = get_value_from_dict(namespace_cache, key)
|
||||
if val is not None:
|
||||
return val
|
||||
|
||||
no_key = no_key_cache_key(namespace, key)
|
||||
if no_key in self._no_key:
|
||||
return default_val
|
||||
|
||||
# read the network configuration
|
||||
namespace_data = self.get_json_from_net(namespace)
|
||||
val = get_value_from_dict(namespace_data, key)
|
||||
if val is not None:
|
||||
self._update_cache_and_file(namespace_data, namespace)
|
||||
return val
|
||||
|
||||
# read the file configuration
|
||||
namespace_cache = self._get_local_cache(namespace)
|
||||
val = get_value_from_dict(namespace_cache, key)
|
||||
if val is not None:
|
||||
self._update_cache_and_file(namespace_cache, namespace)
|
||||
return val
|
||||
|
||||
# If all of them are not obtained, the default value is returned
|
||||
# and the local cache is set to None
|
||||
self._set_local_cache_none(namespace, key)
|
||||
return default_val
|
||||
except Exception:
|
||||
logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace)
|
||||
return default_val
|
||||
|
||||
# Set the key of a namespace to none, and do not set default val
|
||||
# to ensure the real-time correctness of the function call.
|
||||
# If the user does not have the same default val twice
|
||||
# and the default val is used here, there may be a problem.
|
||||
def _set_local_cache_none(self, namespace, key):
|
||||
no_key = no_key_cache_key(namespace, key)
|
||||
self._no_key[no_key] = key
|
||||
|
||||
def _start_hot_update(self):
|
||||
self._long_poll_thread = threading.Thread(target=self._listener)
|
||||
# When the asynchronous thread is started, the daemon thread will automatically exit
|
||||
# when the main thread is launched.
|
||||
self._long_poll_thread.daemon = True
|
||||
self._long_poll_thread.start()
|
||||
|
||||
def stop(self):
|
||||
self._stopping = True
|
||||
logger.info("Stopping listener...")
|
||||
|
||||
# Call the set callback function, and if it is abnormal, try it out
|
||||
def _call_listener(self, namespace, old_kv, new_kv):
|
||||
if self._change_listener is None:
|
||||
return
|
||||
if old_kv is None:
|
||||
old_kv = {}
|
||||
if new_kv is None:
|
||||
new_kv = {}
|
||||
try:
|
||||
for key in old_kv:
|
||||
new_value = new_kv.get(key)
|
||||
old_value = old_kv.get(key)
|
||||
if new_value is None:
|
||||
# If newValue is empty, it means key, and the value is deleted.
|
||||
self._change_listener("delete", namespace, key, old_value)
|
||||
continue
|
||||
if new_value != old_value:
|
||||
self._change_listener("update", namespace, key, new_value)
|
||||
continue
|
||||
for key in new_kv:
|
||||
new_value = new_kv.get(key)
|
||||
old_value = old_kv.get(key)
|
||||
if old_value is None:
|
||||
self._change_listener("add", namespace, key, new_value)
|
||||
except BaseException as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
def _path_checker(self):
|
||||
if not os.path.isdir(self._cache_file_path):
|
||||
makedirs_wrapper(self._cache_file_path)
|
||||
|
||||
# update the local cache and file cache
|
||||
def _update_cache_and_file(self, namespace_data, namespace="application"):
|
||||
# update the local cache
|
||||
self._cache[namespace] = namespace_data
|
||||
# update the file cache
|
||||
new_string = json.dumps(namespace_data)
|
||||
new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest()
|
||||
if self._hash.get(namespace) == new_hash:
|
||||
pass
|
||||
else:
|
||||
file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt"
|
||||
file_path.write_text(new_string)
|
||||
self._hash[namespace] = new_hash
|
||||
|
||||
# get the configuration from the local file
|
||||
def _get_local_cache(self, namespace="application"):
|
||||
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
|
||||
if os.path.isfile(cache_file_path):
|
||||
with open(cache_file_path) as f:
|
||||
result = json.loads(f.readline())
|
||||
return result
|
||||
return {}
|
||||
|
||||
def _long_poll(self):
|
||||
notifications = []
|
||||
for key in self._cache:
|
||||
namespace_data = self._cache[key]
|
||||
notification_id = -1
|
||||
if NOTIFICATION_ID in namespace_data:
|
||||
notification_id = self._cache[key][NOTIFICATION_ID]
|
||||
notifications.append({NAMESPACE_NAME: key, NOTIFICATION_ID: notification_id})
|
||||
try:
|
||||
# if the length is 0 it is returned directly
|
||||
if len(notifications) == 0:
|
||||
return
|
||||
url = "{}/notifications/v2".format(self.config_url)
|
||||
params = {
|
||||
"appId": self.app_id,
|
||||
"cluster": self.cluster,
|
||||
"notifications": json.dumps(notifications, ensure_ascii=False),
|
||||
}
|
||||
param_str = url_encode_wrapper(params)
|
||||
url = url + "?" + param_str
|
||||
code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url))
|
||||
http_code = code
|
||||
if http_code == 304:
|
||||
logger.debug("No change, loop...")
|
||||
return
|
||||
if http_code == 200:
|
||||
if not body:
|
||||
logger.error(f"_long_poll load configs failed,body is {body}")
|
||||
return
|
||||
data = json.loads(body)
|
||||
for entry in data:
|
||||
namespace = entry[NAMESPACE_NAME]
|
||||
n_id = entry[NOTIFICATION_ID]
|
||||
logger.info("%s has changes: notificationId=%d", namespace, n_id)
|
||||
self._get_net_and_set_local(namespace, n_id, call_change=True)
|
||||
return
|
||||
else:
|
||||
logger.warning("Sleep...")
|
||||
except Exception as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
|
||||
namespace_data = self.get_json_from_net(namespace)
|
||||
if not namespace_data:
|
||||
return
|
||||
namespace_data[NOTIFICATION_ID] = n_id
|
||||
old_namespace = self._cache.get(namespace)
|
||||
self._update_cache_and_file(namespace_data, namespace)
|
||||
if self._change_listener is not None and call_change and old_namespace:
|
||||
old_kv = old_namespace.get(CONFIGURATIONS)
|
||||
new_kv = namespace_data.get(CONFIGURATIONS)
|
||||
self._call_listener(namespace, old_kv, new_kv)
|
||||
|
||||
def _listener(self):
|
||||
logger.info("start long_poll")
|
||||
while not self._stopping:
|
||||
self._long_poll()
|
||||
time.sleep(self._cycle_time)
|
||||
logger.info("stopped, long_poll")
|
||||
|
||||
# add the need for endorsement to the header
|
||||
def _sign_headers(self, url):
|
||||
headers = {}
|
||||
if self.secret == "":
|
||||
return headers
|
||||
uri = url[len(self.config_url) : len(url)]
|
||||
time_unix_now = str(int(round(time.time() * 1000)))
|
||||
headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret)
|
||||
headers["Timestamp"] = time_unix_now
|
||||
return headers
|
||||
|
||||
def _heart_beat(self):
|
||||
while not self._stopping:
|
||||
for namespace in self._notification_map:
|
||||
self._do_heart_beat(namespace)
|
||||
time.sleep(60 * 10) # 10分钟
|
||||
|
||||
def _do_heart_beat(self, namespace):
|
||||
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
|
||||
try:
|
||||
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
|
||||
if code == 200:
|
||||
if not body:
|
||||
logger.error(f"_do_heart_beat load configs failed,body is {body}")
|
||||
return None
|
||||
data = json.loads(body)
|
||||
if self.last_release_key == data["releaseKey"]:
|
||||
return None
|
||||
self.last_release_key = data["releaseKey"]
|
||||
data = data["configurations"]
|
||||
self._update_cache_and_file(data, namespace)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("an error occurred in _do_heart_beat")
|
||||
return None
|
||||
|
||||
def get_all_dicts(self, namespace):
|
||||
namespace_data = self._cache.get(namespace)
|
||||
if namespace_data is None:
|
||||
net_namespace_data = self.get_json_from_net(namespace)
|
||||
if not net_namespace_data:
|
||||
return namespace_data
|
||||
namespace_data = net_namespace_data.get(CONFIGURATIONS)
|
||||
if namespace_data:
|
||||
self._update_cache_and_file(namespace_data, namespace)
|
||||
return namespace_data
|
||||
@@ -1,41 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import urllib.request
|
||||
from urllib import parse
|
||||
from urllib.error import HTTPError
|
||||
|
||||
# Create an SSL context that allows for a lower level of security
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers("HIGH:!DH:!aNULL")
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Create an opener object and pass in a custom SSL context
|
||||
opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context))
|
||||
|
||||
urllib.request.install_opener(opener)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def http_request(url, timeout, headers={}):
|
||||
try:
|
||||
request = urllib.request.Request(url, headers=headers)
|
||||
res = urllib.request.urlopen(request, timeout=timeout)
|
||||
body = res.read().decode("utf-8")
|
||||
return res.code, body
|
||||
except HTTPError as e:
|
||||
if e.code == 304:
|
||||
logger.warning("http_request error,code is 304, maybe you should check secret")
|
||||
return 304, None
|
||||
logger.warning("http_request error,code is %d, msg is %s", e.code, e.msg)
|
||||
raise e
|
||||
|
||||
|
||||
def url_encode(params):
|
||||
return parse.urlencode(params)
|
||||
|
||||
|
||||
def makedirs_wrapper(path):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
@@ -1,51 +0,0 @@
|
||||
import hashlib
|
||||
import socket
|
||||
|
||||
from .python_3x import url_encode
|
||||
|
||||
# define constants
|
||||
CONFIGURATIONS = "configurations"
|
||||
NOTIFICATION_ID = "notificationId"
|
||||
NAMESPACE_NAME = "namespaceName"
|
||||
|
||||
|
||||
# add timestamps uris and keys
|
||||
def signature(timestamp, uri, secret):
|
||||
import base64
|
||||
import hmac
|
||||
|
||||
string_to_sign = "" + timestamp + "\n" + uri
|
||||
hmac_code = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha1).digest()
|
||||
return base64.b64encode(hmac_code).decode()
|
||||
|
||||
|
||||
def url_encode_wrapper(params):
|
||||
return url_encode(params)
|
||||
|
||||
|
||||
def no_key_cache_key(namespace, key):
|
||||
return "{}{}{}".format(namespace, len(namespace), key)
|
||||
|
||||
|
||||
# Returns whether the obtained value is obtained, and None if it does not
|
||||
def get_value_from_dict(namespace_cache, key):
|
||||
if namespace_cache:
|
||||
kv_data = namespace_cache.get(CONFIGURATIONS)
|
||||
if kv_data is None:
|
||||
return None
|
||||
if key in kv_data:
|
||||
return kv_data[key]
|
||||
return None
|
||||
|
||||
|
||||
def init_ip():
|
||||
ip = ""
|
||||
s = None
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.connect(("8.8.8.8", 53))
|
||||
ip = s.getsockname()[0]
|
||||
finally:
|
||||
if s:
|
||||
s.close()
|
||||
return ip
|
||||
@@ -1,15 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
class RemoteSettingsSource:
|
||||
def __init__(self, configs: Mapping[str, Any]):
|
||||
pass
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
||||
return value
|
||||
@@ -1,5 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class RemoteSettingsSourceName(StrEnum):
|
||||
APOLLO = "apollo"
|
||||
@@ -14,11 +14,11 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
|
||||
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
|
||||
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
DOCUMENT_EXTENSIONS.append("ppt")
|
||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||
else:
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
|
||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||
|
||||
@@ -62,6 +62,7 @@ from .datasets import (
|
||||
external,
|
||||
hit_testing,
|
||||
website,
|
||||
fta_test,
|
||||
)
|
||||
|
||||
# Import explore controllers
|
||||
|
||||
145
api/controllers/console/datasets/fta_test.py
Normal file
145
api/controllers/console/datasets/fta_test.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
from flask import Response
|
||||
from flask_restful import Resource, reqparse
|
||||
from sqlalchemy import text
|
||||
|
||||
from controllers.console import api
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.fta import ComponentFailure, ComponentFailureStats
|
||||
|
||||
|
||||
class FATTestApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("log_process_data", nullable=False, required=True, type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
print(args["log_process_data"])
|
||||
# Extract the JSON string from the text field
|
||||
json_str = args["log_process_data"].strip("```json\\n").strip("```").strip().replace("\\n", "")
|
||||
log_data = json.loads(json_str)
|
||||
db.session.query(ComponentFailure).delete()
|
||||
for data in log_data:
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError("Data must be a dictionary.")
|
||||
|
||||
required_keys = {"Date", "Component", "FailureMode", "Cause", "RepairAction", "Technician"}
|
||||
if not required_keys.issubset(data.keys()):
|
||||
raise ValueError(f"Data dictionary must contain the following keys: {required_keys}")
|
||||
|
||||
try:
|
||||
# Clear existing stats
|
||||
component_failure = ComponentFailure(
|
||||
Date=data["Date"],
|
||||
Component=data["Component"],
|
||||
FailureMode=data["FailureMode"],
|
||||
Cause=data["Cause"],
|
||||
RepairAction=data["RepairAction"],
|
||||
Technician=data["Technician"],
|
||||
)
|
||||
db.session.add(component_failure)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# Clear existing stats
|
||||
db.session.query(ComponentFailureStats).delete()
|
||||
|
||||
# Insert calculated statistics
|
||||
try:
|
||||
db.session.execute(
|
||||
text("""
|
||||
INSERT INTO component_failure_stats ("Component", "FailureMode", "Cause", "PossibleAction", "Probability", "MTBF")
|
||||
SELECT
|
||||
cf."Component",
|
||||
cf."FailureMode",
|
||||
cf."Cause",
|
||||
cf."RepairAction" as "PossibleAction",
|
||||
COUNT(*) * 1.0 / (SELECT COUNT(*) FROM component_failure WHERE "Component" = cf."Component") AS "Probability",
|
||||
COALESCE(AVG(EXTRACT(EPOCH FROM (next_failure_date::timestamp - cf."Date"::timestamp)) / 86400.0),0)AS "MTBF"
|
||||
FROM (
|
||||
SELECT
|
||||
"Component",
|
||||
"FailureMode",
|
||||
"Cause",
|
||||
"RepairAction",
|
||||
"Date",
|
||||
LEAD("Date") OVER (PARTITION BY "Component", "FailureMode", "Cause" ORDER BY "Date") AS next_failure_date
|
||||
FROM
|
||||
component_failure
|
||||
) cf
|
||||
GROUP BY
|
||||
cf."Component", cf."FailureMode", cf."Cause", cf."RepairAction";
|
||||
""")
|
||||
)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
print(f"Error during stats calculation: {e}")
|
||||
# output format
|
||||
# [
|
||||
# (17, 'Hydraulic system', 'Leak', 'Hose rupture', 'Replaced hydraulic hose', 0.3333333333333333, None),
|
||||
# (18, 'Hydraulic system', 'Leak', 'Seal Wear', 'Replaced the faulty seal', 0.3333333333333333, None),
|
||||
# (19, 'Hydraulic system', 'Pressure drop', 'Fluid leak', 'Replaced hydraulic fluid and seals', 0.3333333333333333, None)
|
||||
# ]
|
||||
|
||||
component_failure_stats = db.session.query(ComponentFailureStats).all()
|
||||
# Convert stats to list of tuples format
|
||||
stats_list = []
|
||||
for stat in component_failure_stats:
|
||||
stats_list.append(
|
||||
(
|
||||
stat.StatID,
|
||||
stat.Component,
|
||||
stat.FailureMode,
|
||||
stat.Cause,
|
||||
stat.PossibleAction,
|
||||
stat.Probability,
|
||||
stat.MTBF,
|
||||
)
|
||||
)
|
||||
return {"data": stats_list}, 200
|
||||
|
||||
|
||||
# generate-fault-tree
|
||||
class GenerateFaultTreeApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("llm_text", nullable=False, required=True, type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
entities = args["llm_text"].replace("```", "").replace("\\n", "\n")
|
||||
print(entities)
|
||||
request_data = {"fault_tree_text": entities}
|
||||
url = "https://fta.cognitech-dev.live/generate-fault-tree"
|
||||
headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
|
||||
response = requests.post(url, json=request_data, headers=headers)
|
||||
print(response.json())
|
||||
return {"data": response.json()}, 200
|
||||
|
||||
|
||||
class ExtractSVGApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("svg_text", nullable=False, required=True, type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
# svg_text = ''.join(args["svg_text"].splitlines())
|
||||
svg_text = args["svg_text"].replace("\n", "")
|
||||
svg_text = svg_text.replace('"', '"')
|
||||
print(svg_text)
|
||||
svg_text_json = json.loads(svg_text)
|
||||
svg_content = svg_text_json.get("data").get("svg_content")[0]
|
||||
svg_content = svg_content.replace("\n", "").replace('"', '"')
|
||||
file_key = "fta_svg/" + "fat.svg"
|
||||
if storage.exists(file_key):
|
||||
storage.delete(file_key)
|
||||
storage.save(file_key, svg_content.encode("utf-8"))
|
||||
generator = storage.load(file_key, stream=True)
|
||||
|
||||
return Response(generator, mimetype="image/svg+xml")
|
||||
|
||||
|
||||
api.add_resource(FATTestApi, "/fta/db-handler")
|
||||
api.add_resource(GenerateFaultTreeApi, "/fta/generate-fault-tree")
|
||||
api.add_resource(ExtractSVGApi, "/fta/extract-svg")
|
||||
@@ -1,6 +1,5 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse
|
||||
from sqlalchemy import and_
|
||||
@@ -21,17 +20,8 @@ class InstalledAppsListApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(installed_app_list_fields)
|
||||
def get(self):
|
||||
app_id = request.args.get("app_id", default=None, type=str)
|
||||
current_tenant_id = current_user.current_tenant_id
|
||||
|
||||
if app_id:
|
||||
installed_apps = (
|
||||
db.session.query(InstalledApp)
|
||||
.filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
|
||||
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
|
||||
|
||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||
installed_apps = [
|
||||
|
||||
@@ -368,7 +368,6 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
description=args["description"],
|
||||
parameters=args["parameters"],
|
||||
privacy_policy=args["privacy_policy"],
|
||||
labels=args["labels"],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks.
|
||||
|
||||
Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.
|
||||
Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid deattach errors.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -36,29 +36,6 @@ logger = logging.getLogger(__name__)
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
_dialogue_count: int
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
@@ -67,17 +44,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
@@ -32,7 +31,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
@@ -129,6 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
self._conversation_name_generate_thread = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
self.total_tokens: int = 0
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
@@ -319,7 +318,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
@@ -362,6 +361,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
# FIXME for issue #11221 quick fix maybe have a better solution
|
||||
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
@@ -375,7 +376,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
conversation_id=self._conversation.id,
|
||||
@@ -386,29 +387,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
exceptions_count=event.exceptions_count,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
if not workflow_run:
|
||||
@@ -426,7 +404,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
error=event.error,
|
||||
conversation_id=self._conversation.id,
|
||||
trace_manager=trace_manager,
|
||||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -28,39 +28,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
@@ -69,7 +36,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
for resource in metadata["retriever_resources"]:
|
||||
updated_resources.append(
|
||||
{
|
||||
"segment_id": resource.get("segment_id", ""),
|
||||
"segment_id": resource["segment_id"],
|
||||
"position": resource["position"],
|
||||
"document_name": resource["document_name"],
|
||||
"score": resource["score"],
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -34,9 +34,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@@ -44,29 +44,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
) -> Union[dict, Generator[str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -34,9 +34,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@@ -44,29 +44,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
|
||||
) -> Union[dict, Generator[str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -30,35 +30,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
@@ -70,20 +41,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
):
|
||||
) -> Mapping[str, Any] | Generator[str, None, None]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
# parse files
|
||||
|
||||
@@ -6,7 +6,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
@@ -35,8 +34,7 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent
|
||||
| QueueWorkflowPartialSuccessEvent,
|
||||
| QueueWorkflowFailedEvent,
|
||||
):
|
||||
self.stop_listen()
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
@@ -27,7 +26,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
@@ -108,6 +106,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
self.total_tokens: int = 0
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@@ -259,36 +258,36 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
|
||||
|
||||
node_start_response = self._workflow_node_start_to_stream_response(
|
||||
response = self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_start_response:
|
||||
yield node_start_response
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||
|
||||
node_success_response = self._workflow_node_finish_to_stream_response(
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_success_response:
|
||||
yield node_success_response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
node_failed_response = self._workflow_node_finish_to_stream_response(
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
)
|
||||
|
||||
if node_failed_response:
|
||||
yield node_failed_response
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
@@ -321,6 +320,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
# FIXME for issue #11221 quick fix maybe have a better solution
|
||||
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
|
||||
)
|
||||
@@ -334,7 +335,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
conversation_id=None,
|
||||
@@ -344,30 +345,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(workflow_run)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||
if not workflow_run:
|
||||
raise Exception("Workflow run not initialized.")
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=event.outputs,
|
||||
exceptions_count=event.exceptions_count,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(workflow_run)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
@@ -377,6 +354,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception("Graph runtime state not initialized.")
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -388,7 +366,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
|
||||
@@ -8,7 +8,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
@@ -19,7 +18,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
@@ -27,7 +25,6 @@ from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
@@ -35,7 +32,6 @@ from core.workflow.graph_engine.entities.event import (
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
@@ -180,12 +176,8 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
)
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
|
||||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeStartedEvent(
|
||||
@@ -261,36 +253,6 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunExceptionEvent):
|
||||
self._publish_event(
|
||||
QueueNodeExceptionEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInIterationFailedEvent(
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
@@ -25,14 +25,12 @@ class QueueEvent(StrEnum):
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_SUCCEEDED = "workflow_succeeded"
|
||||
WORKFLOW_FAILED = "workflow_failed"
|
||||
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
|
||||
ITERATION_START = "iteration_start"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_SUCCEEDED = "node_succeeded"
|
||||
NODE_FAILED = "node_failed"
|
||||
NODE_EXCEPTION = "node_exception"
|
||||
RETRIEVER_RESOURCES = "retriever_resources"
|
||||
ANNOTATION_REPLY = "annotation_reply"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
@@ -115,6 +113,18 @@ class QueueIterationNextEvent(AppQueueEvent):
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
duration: Optional[float] = None
|
||||
|
||||
@field_validator("output", mode="before")
|
||||
@classmethod
|
||||
def set_output(cls, v):
|
||||
"""
|
||||
Set output
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, int | float | str | bool | dict | list):
|
||||
return v
|
||||
raise ValueError("output must be a valid type")
|
||||
|
||||
|
||||
class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
"""
|
||||
@@ -239,17 +249,6 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_FAILED
|
||||
error: str
|
||||
exceptions_count: int
|
||||
|
||||
|
||||
class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueWorkflowFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
|
||||
exceptions_count: int
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class QueueNodeStartedEvent(AppQueueEvent):
|
||||
@@ -344,37 +343,6 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeExceptionEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_EXCEPTION
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeFailedEvent entity
|
||||
|
||||
@@ -213,7 +213,6 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
created_by: Optional[dict] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
exceptions_count: Optional[int] = 0
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
|
||||
|
||||
@@ -110,7 +110,7 @@ class RateLimitGenerator:
|
||||
raise StopIteration
|
||||
try:
|
||||
return next(self.generator)
|
||||
except Exception:
|
||||
except StopIteration:
|
||||
self.close()
|
||||
raise
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
@@ -165,55 +164,6 @@ class WorkflowCycleManage:
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_partial_success(
|
||||
self,
|
||||
workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
exceptions_count: int = 0,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run success
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param outputs: outputs
|
||||
:param conversation_id: conversation id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
|
||||
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
|
||||
workflow_run.outputs = json.dumps(outputs or {})
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_run=workflow_run,
|
||||
conversation_id=conversation_id,
|
||||
user_id=trace_manager.user_id,
|
||||
)
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_failed(
|
||||
self,
|
||||
workflow_run: WorkflowRun,
|
||||
@@ -224,7 +174,6 @@ class WorkflowCycleManage:
|
||||
error: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
exceptions_count: int = 0,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run failed
|
||||
@@ -244,7 +193,7 @@ class WorkflowCycleManage:
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
|
||||
db.session.commit()
|
||||
|
||||
running_workflow_node_executions = (
|
||||
@@ -369,7 +318,7 @@ class WorkflowCycleManage:
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(
|
||||
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent
|
||||
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
@@ -388,11 +337,7 @@ class WorkflowCycleManage:
|
||||
)
|
||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
|
||||
{
|
||||
WorkflowNodeExecution.status: (
|
||||
WorkflowNodeExecutionStatus.FAILED.value
|
||||
if not isinstance(event, QueueNodeExceptionEvent)
|
||||
else WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
),
|
||||
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
|
||||
WorkflowNodeExecution.error: event.error,
|
||||
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
|
||||
@@ -406,11 +351,8 @@ class WorkflowCycleManage:
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
workflow_node_execution.status = (
|
||||
WorkflowNodeExecutionStatus.FAILED.value
|
||||
if not isinstance(event, QueueNodeExceptionEvent)
|
||||
else WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
@@ -491,7 +433,6 @@ class WorkflowCycleManage:
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
|
||||
exceptions_count=workflow_run.exceptions_count,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -542,10 +483,7 @@ class WorkflowCycleManage:
|
||||
|
||||
def _workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
event: QueueNodeSucceededEvent
|
||||
| QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_repository
|
||||
@@ -18,6 +20,38 @@ from .models import File, FileTransferMethod, FileType
|
||||
from .tool_file_parser import ToolFileParser
|
||||
|
||||
|
||||
def download_to_target_path(f: File, temp_dir: str, /):
|
||||
if f.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
|
||||
suffix = Path(tool_file.file_key).suffix
|
||||
target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
_download_file_to_target_path(tool_file.file_key, target_path)
|
||||
return target_path
|
||||
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
|
||||
suffix = Path(upload_file.key).suffix
|
||||
target_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
|
||||
_download_file_to_target_path(upload_file.key, target_path)
|
||||
return target_path
|
||||
else:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
|
||||
def _download_file_to_target_path(path: str, target_path: str, /):
|
||||
"""
|
||||
Download and return the contents of a file as bytes.
|
||||
|
||||
This function loads the file from storage and ensures it's in bytes format.
|
||||
|
||||
Args:
|
||||
path (str): The path to the file in storage.
|
||||
target_path (str): The path to the target file.
|
||||
Raises:
|
||||
ValueError: If the loaded file is not a bytes object.
|
||||
"""
|
||||
storage.download(path, target_path)
|
||||
|
||||
|
||||
def get_attr(*, file: File, attr: FileAttribute):
|
||||
match attr:
|
||||
case FileAttribute.TYPE:
|
||||
@@ -141,7 +175,7 @@ def _to_url(f: File, /):
|
||||
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if f.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
|
||||
return helpers.get_signed_file_url(upload_file_id=f.related_id)
|
||||
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
# add sign url
|
||||
if f.related_id is None or f.extension is None:
|
||||
|
||||
@@ -24,12 +24,6 @@ BACKOFF_FACTOR = 0.5
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
|
||||
class MaxRetriesExceededError(Exception):
|
||||
"""Raised when the maximum number of retries is exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
if "allow_redirects" in kwargs:
|
||||
allow_redirects = kwargs.pop("allow_redirects")
|
||||
@@ -70,7 +64,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
if retries <= max_retries:
|
||||
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
|
||||
|
||||
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
|
||||
raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")
|
||||
|
||||
|
||||
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
|
||||
@@ -91,7 +91,7 @@ class XinferenceProvider(Provider):
|
||||
"""
|
||||
```
|
||||
|
||||
也可以直接抛出对应 Errors,并做如下定义,这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
|
||||
也可以直接抛出对应Erros,并做如下定义,这样在之后的调用中可以直接抛出`InvokeConnectionError`等异常。
|
||||
|
||||
```python
|
||||
@property
|
||||
|
||||
@@ -10,7 +10,6 @@ from core.model_runtime.entities.llm_entities import (
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
@@ -106,11 +105,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_dict = {"role": "user", "content": message_content.data}
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
raise ValueError("Content object type not support image_url")
|
||||
raise ValueError("User message content must be str")
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
|
||||
@@ -16,7 +16,6 @@ help:
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
|
||||
|
||||
def get_bedrock_client(service_name, credentials=None):
|
||||
client_config = Config(region_name=credentials["aws_region"])
|
||||
aws_access_key_id = credentials["aws_access_key_id"]
|
||||
aws_secret_access_key = credentials["aws_secret_access_key"]
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
# use aksk to call bedrock
|
||||
client = boto3.client(
|
||||
service_name=service_name,
|
||||
config=client_config,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
else:
|
||||
# use iam without aksk to call
|
||||
client = boto3.client(service_name=service_name, config=client_config)
|
||||
|
||||
return client
|
||||
@@ -6,7 +6,6 @@ features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
|
||||
@@ -6,7 +6,6 @@ features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
|
||||
@@ -40,7 +40,6 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||
@@ -174,7 +173,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param stream: is stream response
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
bedrock_client = get_bedrock_client("bedrock-runtime", credentials)
|
||||
bedrock_client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
||||
region_name=credentials["aws_region"],
|
||||
)
|
||||
|
||||
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
|
||||
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
|
||||
@@ -6,7 +6,6 @@ features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 300000
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
- amazon.rerank-v1
|
||||
- cohere.rerank-v3-5
|
||||
@@ -1,4 +0,0 @@
|
||||
model: amazon.rerank-v1:0
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 5120
|
||||
@@ -1,4 +0,0 @@
|
||||
model: cohere.rerank-v3-5:0
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 5120
|
||||
@@ -1,139 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
|
||||
|
||||
|
||||
class BedrockRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for Cohere rerank model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=docs)
|
||||
|
||||
# initialize client
|
||||
bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
|
||||
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
|
||||
text_sources = []
|
||||
for text in docs:
|
||||
text_sources.append(
|
||||
{
|
||||
"type": "INLINE",
|
||||
"inlineDocumentSource": {
|
||||
"type": "TEXT",
|
||||
"textDocument": {
|
||||
"text": text,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
modelId = model
|
||||
region = credentials["aws_region"]
|
||||
model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
|
||||
rerankingConfiguration = {
|
||||
"type": "BEDROCK_RERANKING_MODEL",
|
||||
"bedrockRerankingConfiguration": {
|
||||
"numberOfResults": top_n,
|
||||
"modelConfiguration": {
|
||||
"modelArn": model_package_arn,
|
||||
},
|
||||
},
|
||||
}
|
||||
response = bedrock_runtime.rerank(
|
||||
queries=queries, sources=text_sources, rerankingConfiguration=rerankingConfiguration
|
||||
)
|
||||
|
||||
rerank_documents = []
|
||||
for idx, result in enumerate(response["results"]):
|
||||
# format document
|
||||
index = result["index"]
|
||||
rerank_document = RerankDocument(
|
||||
index=index,
|
||||
text=docs[index],
|
||||
score=result["relevanceScore"],
|
||||
)
|
||||
|
||||
# score threshold check
|
||||
if score_threshold is not None:
|
||||
if rerank_document.score >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
else:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self.invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: [],
|
||||
}
|
||||
@@ -3,6 +3,8 @@ import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import boto3
|
||||
from botocore.config import Config
|
||||
from botocore.exceptions import (
|
||||
ClientError,
|
||||
EndpointConnectionError,
|
||||
@@ -23,7 +25,6 @@ from core.model_runtime.errors.invoke import (
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -47,7 +48,14 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)
|
||||
client_config = Config(region_name=credentials["aws_region"])
|
||||
|
||||
bedrock_runtime = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
config=client_config,
|
||||
aws_access_key_id=credentials.get("aws_access_key_id"),
|
||||
aws_secret_access_key=credentials.get("aws_secret_access_key"),
|
||||
)
|
||||
|
||||
embeddings = []
|
||||
token_usage = 0
|
||||
|
||||
@@ -2,4 +2,3 @@
|
||||
- rerank-english-v3.0
|
||||
- rerank-multilingual-v2.0
|
||||
- rerank-multilingual-v3.0
|
||||
- rerank-v3.5
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
model: rerank-v3.5
|
||||
model_type: rerank
|
||||
model_properties:
|
||||
context_size: 5120
|
||||
@@ -24,9 +24,6 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
# {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}}
|
||||
if "response_format" in model_parameters:
|
||||
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
model: gemini-2.0-flash-exp
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash Exp
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@@ -1,38 +0,0 @@
|
||||
model: gemini-exp-1206
|
||||
label:
|
||||
en_US: Gemini exp 1206
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@@ -1,5 +1,4 @@
|
||||
- llama-3.1-405b-reasoning
|
||||
- llama-3.3-70b-versatile
|
||||
- llama-3.1-70b-versatile
|
||||
- llama-3.1-8b-instant
|
||||
- llama3-70b-8192
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
model: gemma-7b-it
|
||||
label:
|
||||
zh_Hans: Gemma 7B Instruction Tuned
|
||||
en_US: Gemma 7B Instruction Tuned
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@@ -1,25 +0,0 @@
|
||||
model: gemma2-9b-it
|
||||
label:
|
||||
zh_Hans: Gemma 2 9B Instruction Tuned
|
||||
en_US: Gemma 2 9B Instruction Tuned
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@@ -1,8 +1,7 @@
|
||||
model: llama-3.1-70b-versatile
|
||||
deprecated: true
|
||||
label:
|
||||
zh_Hans: Llama-3.1-70b-versatile (DEPRECATED)
|
||||
en_US: Llama-3.1-70b-versatile (DEPRECATED)
|
||||
zh_Hans: Llama-3.1-70b-versatile
|
||||
en_US: Llama-3.1-70b-versatile
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
model: llama-3.2-11b-text-preview
|
||||
deprecated: true
|
||||
label:
|
||||
zh_Hans: Llama 3.2 11B Text (Preview)
|
||||
en_US: Llama 3.2 11B Text (Preview)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
model: llama-3.2-90b-text-preview
|
||||
depraceted: true
|
||||
label:
|
||||
zh_Hans: Llama 3.2 90B Text (Preview)
|
||||
en_US: Llama 3.2 90B Text (Preview)
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
model: llama-3.3-70b-specdec
|
||||
label:
|
||||
zh_Hans: Llama 3.3 70B Specdec
|
||||
en_US: Llama 3.3 70B Specdec
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32768
|
||||
pricing:
|
||||
input: "0.05"
|
||||
output: "0.1"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
||||
@@ -1,25 +0,0 @@
|
||||
model: llama-3.3-70b-versatile
|
||||
label:
|
||||
zh_Hans: Llama 3.3 70B Versatile
|
||||
en_US: Llama 3.3 70B Versatile
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32768
|
||||
pricing:
|
||||
input: "0.05"
|
||||
output: "0.1"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
||||
@@ -1,25 +0,0 @@
|
||||
model: llama3-groq-70b-8192-tool-use-preview
|
||||
label:
|
||||
zh_Hans: Llama3-groq-70b-8192-tool-use (PREVIEW)
|
||||
en_US: Llama3-groq-70b-8192-tool-use (PREVIEW)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.08'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@@ -35,7 +35,6 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
|
||||
class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
model_apis = {
|
||||
"abab7-chat-preview": MinimaxChatCompletionPro,
|
||||
"abab6.5t-chat": MinimaxChatCompletionPro,
|
||||
"abab6.5s-chat": MinimaxChatCompletionPro,
|
||||
"abab6.5-chat": MinimaxChatCompletionPro,
|
||||
"abab6-chat": MinimaxChatCompletionPro,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
- pixtral-large-latest
|
||||
- pixtral-large-2411
|
||||
- pixtral-12b-2409
|
||||
- codestral-latest
|
||||
- mistral-embed
|
||||
|
||||
@@ -5,7 +5,6 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
@@ -22,7 +21,7 @@ parameter_rules:
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: safe_prompt
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
model: pixtral-large-2411
|
||||
label:
|
||||
zh_Hans: pixtral-large-2411
|
||||
en_US: pixtral-large-2411
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: safe_prompt
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@@ -1,52 +0,0 @@
|
||||
model: pixtral-large-latest
|
||||
label:
|
||||
zh_Hans: pixtral-large-latest
|
||||
en_US: pixtral-large-latest
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: safe_prompt
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@@ -181,11 +181,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
# prepare the payload for a simple ping to the model
|
||||
data = {"model": model, "stream": stream}
|
||||
|
||||
if format_schema := model_parameters.pop("format", None):
|
||||
try:
|
||||
data["format"] = format_schema if format_schema == "json" else json.loads(format_schema)
|
||||
except json.JSONDecodeError as e:
|
||||
raise InvokeBadRequestError(f"Invalid format schema: {str(e)}")
|
||||
if "format" in model_parameters:
|
||||
data["format"] = model_parameters["format"]
|
||||
del model_parameters["format"]
|
||||
|
||||
if "keep_alive" in model_parameters:
|
||||
data["keep_alive"] = model_parameters["keep_alive"]
|
||||
@@ -735,12 +733,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
ParameterRule(
|
||||
name="format",
|
||||
label=I18nObject(en_US="Format", zh_Hans="返回格式"),
|
||||
type=ParameterType.TEXT,
|
||||
default="json",
|
||||
type=ParameterType.STRING,
|
||||
help=I18nObject(
|
||||
en_US="the format to return a response in. Format can be `json` or a JSON schema.",
|
||||
zh_Hans="返回响应的格式。目前接受的值是字符串`json`或JSON schema.",
|
||||
en_US="the format to return a response in. Currently the only accepted value is json.",
|
||||
zh_Hans="返回响应的格式。目前唯一接受的值是json。",
|
||||
),
|
||||
options=["json"],
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
|
||||
@@ -478,10 +478,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
usage=usage,
|
||||
)
|
||||
break
|
||||
# handle the error here. for issue #11629
|
||||
if chunk_json.get("error") and chunk_json.get("choices") is None:
|
||||
raise ValueError(chunk_json.get("error"))
|
||||
|
||||
if chunk_json:
|
||||
if u := chunk_json.get("usage"):
|
||||
usage = u
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
- Tencent/Hunyuan-A52B-Instruct
|
||||
- Qwen/QwQ-32B-Preview
|
||||
- Qwen/Qwen2.5-72B-Instruct
|
||||
- Qwen/Qwen2.5-32B-Instruct
|
||||
- Qwen/Qwen2.5-14B-Instruct
|
||||
@@ -20,7 +19,6 @@
|
||||
- 01-ai/Yi-1.5-6B-Chat
|
||||
- internlm/internlm2_5-20b-chat
|
||||
- internlm/internlm2_5-7b-chat
|
||||
- meta-llama/Llama-3.3-70B-Instruct
|
||||
- meta-llama/Meta-Llama-3.1-405B-Instruct
|
||||
- meta-llama/Meta-Llama-3.1-70B-Instruct
|
||||
- meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
model: meta-llama/Llama-3.3-70B-Instruct
|
||||
label:
|
||||
en_US: meta-llama/Llama-3.3-70B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '4.13'
|
||||
output: '4.13'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
@@ -1,53 +0,0 @@
|
||||
model: Qwen/QwQ-32B-Preview
|
||||
label:
|
||||
en_US: Qwen/QwQ-32B-Preview
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '1.26'
|
||||
output: '1.26'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
@@ -59,6 +59,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
|
||||
@@ -59,6 +59,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
|
||||
@@ -58,6 +58,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
|
||||
@@ -59,6 +59,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
|
||||
@@ -59,6 +59,8 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。
|
||||
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: repetition_penalty
|
||||
required: false
|
||||
type: float
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
model: gemini-2.0-flash-exp
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash Exp
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@@ -104,14 +104,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
# use Anthropic official SDK references
|
||||
# - https://github.com/anthropics/anthropic-sdk-python
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
token = ""
|
||||
|
||||
# get access token from service account credential
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
if service_account_info:
|
||||
credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
|
||||
request = google.auth.transport.requests.Request()
|
||||
credentials.refresh(request)
|
||||
@@ -479,11 +478,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
|
||||
if stop:
|
||||
config_kwargs["stop_sequences"] = stop
|
||||
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
if service_account_info:
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
|
||||
@@ -48,11 +48,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
|
||||
:param input_type: input type
|
||||
:return: embeddings result
|
||||
"""
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
if service_account_info:
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
@@ -101,11 +100,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
service_account_key = credentials.get("vertex_service_account_key", "")
|
||||
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
|
||||
project_id = credentials["vertex_project_id"]
|
||||
location = credentials["vertex_location"]
|
||||
if service_account_key:
|
||||
service_account_info = json.loads(base64.b64decode(service_account_key))
|
||||
if service_account_info:
|
||||
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
|
||||
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
|
||||
else:
|
||||
|
||||
@@ -68,12 +68,7 @@ class MaaSClient(MaasService):
|
||||
content = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": message_content.data,
|
||||
}
|
||||
)
|
||||
raise ValueError("Content object type only support image_url")
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data)
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
model: grok-2-1212
|
||||
label:
|
||||
en_US: grok-2-1212
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
|
||||
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@@ -1,64 +0,0 @@
|
||||
model: grok-2-vision-1212
|
||||
label:
|
||||
en_US: grok-2-vision-1212
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
|
||||
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
@@ -1,6 +1,6 @@
|
||||
model: grok-beta
|
||||
label:
|
||||
en_US: grok-beta
|
||||
en_US: Grok Beta
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
model: grok-vision-beta
|
||||
label:
|
||||
en_US: grok-vision-beta
|
||||
en_US: Grok Vision Beta
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user