Compare commits


1 Commit

| Author | SHA1 | Message | Date |
| ------ | ---------- | ---------------- | -------------------------- |
| jyong | ea5e8ee7cc | calculate tokens | 2024-07-15 18:39:25 +08:00 |
304 changed files with 1772 additions and 12386 deletions

View File

@@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o
Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and steps to troubleshoot.
### 5. Visit dify in your browser

View File

@@ -2,17 +2,17 @@
考虑到我们的现状,我们需要灵活快速地交付,但我们也希望确保像你这样的贡献者在贡献过程中获得尽可能顺畅的体验。我们为此编写了这份贡献指南,旨在让你熟悉代码库和我们与贡献者的合作方式,以便你能快速进入有趣的部分。
这份指南,就像 Dify 本身一样,是一个不断改进的工作。如果有时它落后于实际项目,我们非常感谢你的理解,并欢迎提供任何反馈以供我们改进。
这份指南,就像 Dify 本身一样,是一个不断改进的工作。如果有时它落后于实际项目,我们非常感谢你的理解,并欢迎任何反馈以供我们改进。
在许可方面,请花一分钟阅读我们简短的 [许可证和贡献者协议](./LICENSE)。社区还遵守 [行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
在许可方面,请花一分钟阅读我们简短的[许可证和贡献者协议](./LICENSE)。社区还遵守[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
## 在开始之前
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或[创建](https://github.com/langgenius/dify/issues/new/choose)一个新问题。我们将问题分为两类:
### 功能请求:
* 如果您要提出新的功能请求,请解释所提议的功能的目标,并尽可能提供详细的上下文。[@perzeusss](https://github.com/perzeuss) 制作了一个很好的 [功能请求助手](https://udify.app/chat/MK2kVSnw1gakVwMX),可以帮助您起草需求。随时尝试一下。
* 如果您要提出新的功能请求,请解释所提议的功能的目标,并尽可能提供详细的上下文。[@perzeusss](https://github.com/perzeuss)制作了一个很好的[功能请求助手](https://udify.app/chat/MK2kVSnw1gakVwMX),可以帮助您起草需求。随时尝试一下。
* 如果您想从现有问题中选择一个,请在其下方留下评论表示您的意愿。
@@ -20,44 +20,45 @@
根据所提议的功能所属的领域不同,您可能需要与不同的团队成员交流。以下是我们团队成员目前正在从事的各个领域的概述:
| 团队成员 | 工作范围 |
| Member | Scope |
| ------------------------------------------------------------ | ---------------------------------------------------- |
| [@yeuoly](https://github.com/Yeuoly) | 架构 Agents |
| [@jyong](https://github.com/JohnJyong) | RAG 流水线设计 |
| [@GarfieldDai](https://github.com/GarfieldDai) | 构建 workflow 编排 |
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | 让我们的前端更易用 |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 开发人员体验, 综合事项联系人 |
| [@takatost](https://github.com/takatost) | 产品整体方向和架构 |
| [@yeuoly](https://github.com/Yeuoly) | Architecting Agents |
| [@jyong](https://github.com/JohnJyong) | RAG pipeline design |
| [@GarfieldDai](https://github.com/GarfieldDai) | Building workflow orchestrations |
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Making our frontend a breeze to use |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Developer experience, points of contact for anything |
| [@takatost](https://github.com/takatost) | Overall product direction and architecture |
事项优先级:
How we prioritize:
| 功能类型 | 优先级 |
| Feature Type | Priority |
| ------------------------------------------------------------ | --------------- |
| 被团队成员标记为高优先级的功能 | 高优先级 |
| [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) 内反馈的常见功能请求 | 中等优先级 |
| 非核心功能和小幅改进 | 低优先级 |
| 有价值但不紧急 | 未来功能 |
| High-Priority Features as being labeled by a team member | High Priority |
| Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority |
| Non-core features and minor enhancements | Low Priority |
| Valuable but not immediate | Future-Feature |
### 其他任何事情(例如 bug 报告、性能优化、拼写错误更正):
### 其他任何事情(例如bug报告、性能优化、拼写错误更正):
* 立即开始编码。
事项优先级:
How we prioritize:
| Issue 类型 | 优先级 |
| Issue Type | Priority |
| ------------------------------------------------------------ | --------------- |
| 核心功能的 Bugs(例如无法登录、应用无法工作、安全漏洞) | 紧急 |
| 非紧急 bugs, 性能提升 | 中等优先级 |
| 小幅修复(错别字, 能正常工作但存在误导的 UI) | 低优先级 |
| Bugs in core functions (cannot login, applications not working, security loopholes) | Critical |
| Non-critical bugs, performance boosts | Medium Priority |
| Minor fixes (typos, confusing but working UI) | Low Priority |
## 安装
以下是设置 Dify 进行开发的步骤:
以下是设置Dify进行开发的步骤:
### 1. Fork 该仓库
### 1. Fork该仓库
### 2. 克隆仓库
从终端克隆代码仓库:
从终端克隆fork的仓库:
```
git clone git@github.com:<github_username>/dify.git
@@ -75,72 +76,72 @@ Dify 依赖以下工具和库:
### 4. 安装
Dify 由后端和前端组成。通过 `cd api/` 导航到后端目录,然后按照 [后端 README](api/README.md) 进行安装。在另一个终端中,通过 `cd web/` 导航到前端目录,然后按照 [前端 README](web/README.md) 进行安装。
Dify由后端和前端组成。通过`cd api/`导航到后端目录,然后按照[后端README](api/README.md)进行安装。在另一个终端中,通过`cd web/`导航到前端目录,然后按照[前端README](web/README.md)进行安装。
查看 [安装常见问题解答](https://docs.dify.ai/v/zh-hans/learn-more/faq/install-faq) 以获取常见问题列表和故障排除步骤。
查看[安装常见问题解答](https://docs.dify.ai/getting-started/faq/install-faq)以获取常见问题列表和故障排除步骤。
### 5. 在浏览器中访问 Dify
### 5. 在浏览器中访问Dify
为了验证您的设置,打开浏览器并访问 [http://localhost:3000](http://localhost:3000)(默认或您自定义的 URL 和端口)。现在您应该看到 Dify 正在运行。
为了验证您的设置,打开浏览器并访问[http://localhost:3000](http://localhost:3000)(默认或您自定义的URL和端口)。现在您应该看到Dify正在运行。
## 开发
如果您要添加模型提供程序,请参考 [此指南](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)。
如果您要添加模型提供程序,请参考[此指南](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)。
如果您要向 AgentWorkflow 添加工具提供程序,请参考 [此指南](./api/core/tools/README.md)。
如果您要向AgentWorkflow添加工具提供程序,请参考[此指南](./api/core/tools/README.md)。
为了帮助您快速了解您的贡献在哪个部分,以下是 Dify 后端和前端的简要注释大纲:
为了帮助您快速了解您的贡献在哪个部分,以下是Dify后端和前端的简要注释大纲:
### 后端
Dify 的后端使用 Python 编写,使用 [Flask](https://flask.palletsprojects.com/en/3.0.x/) 框架。它使用 [SQLAlchemy](https://www.sqlalchemy.org/) 作为 ORM使用 [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) 作为任务队列。授权逻辑通过 Flask-login 进行处理。
Dify的后端使用Python编写,使用[Flask](https://flask.palletsprojects.com/en/3.0.x/)框架。它使用[SQLAlchemy](https://www.sqlalchemy.org/)作为ORM,使用[Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html)作为任务队列。授权逻辑通过Flask-login进行处理。
```
[api/]
├── constants // 用于整个代码库的常量设置。
├── controllers // API 路由定义和请求处理逻辑。
├── core // 核心应用编排、模型集成和工具。
├── docker // Docker 和容器化相关配置。
├── events // 事件处理和处理。
├── extensions // 与第三方框架/平台的扩展。
├── fields // 用于序列化/封装的字段定义。
├── libs // 可重用的库和助手。
├── migrations // 数据库迁移脚本。
├── models // 数据库模型和架构定义。
├── services // 指定业务逻辑。
├── storage // 私钥存储。
├── tasks // 异步任务和后台作业的处理。
├── constants // Constant settings used throughout code base.
├── controllers // API route definitions and request handling logic.
├── core // Core application orchestration, model integrations, and tools.
├── docker // Docker & containerization related configurations.
├── events // Event handling and processing
├── extensions // Extensions with 3rd party frameworks/platforms.
├── fields // field definitions for serialization/marshalling.
├── libs // Reusable libraries and helpers.
├── migrations // Scripts for database migration.
├── models // Database models & schema definitions.
├── services // Specifies business logic.
├── storage // Private key storage.
├── tasks // Handling of async tasks and background jobs.
└── tests
```
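For orientation, here is a minimal, self-contained sketch of how a Flask + SQLAlchemy + Celery stack like the one described above typically fits together. This is illustrative only, not Dify's actual wiring; the model, broker URL, and task name are assumptions.

```
# Minimal Flask + SQLAlchemy + Celery sketch (illustrative, not Dify's code).
from celery import Celery
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///demo.db"
db = SQLAlchemy(app)

class Document(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    status = db.Column(db.String(32), default="waiting")

with app.app_context():
    db.create_all()

# Celery plays the task-queue role; a local Redis broker is assumed here.
celery_app = Celery(app.name, broker="redis://localhost:6379/0")

@celery_app.task
def index_document(document_id: int):
    # A background job: load a row and flip its status.
    with app.app_context():
        doc = db.session.get(Document, document_id)
        if doc:
            doc.status = "completed"
            db.session.commit()
```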
### 前端
该网站使用基于 Typescript 的 [Next.js](https://nextjs.org/) 模板进行引导,并使用 [Tailwind CSS](https://tailwindcss.com/) 进行样式设计。[React-i18next](https://react.i18next.com/) 用于国际化。
该网站使用基于Typescript的[Next.js](https://nextjs.org/)模板进行引导,并使用[Tailwind CSS](https://tailwindcss.com/)进行样式设计。[React-i18next](https://react.i18next.com/)用于国际化。
```
[web/]
├── app // 布局、页面和组件
│ ├── (commonLayout) // 整个应用通用的布局
│ ├── (shareLayout) // 在特定会话中共享的布局
│ ├── activate // 激活页面
│ ├── components // 页面和布局共享的组件
│ ├── install // 安装页面
│ ├── signin // 登录页面
│ └── styles // 全局共享的样式
├── assets // 静态资源
├── bin // 构建步骤运行的脚本
├── config // 可调整的设置和选项
├── context // 应用中不同部分使用的共享上下文
├── dictionaries // 语言特定的翻译文件
├── docker // 容器配置
├── hooks // 可重用的钩子
├── i18n // 国际化配置
├── models // 描述数据模型和 API 响应的形状
├── public // favicon 等元资源
├── service // 定义 API 操作的形状
├── app // layouts, pages, and components
│ ├── (commonLayout) // common layout used throughout the app
│ ├── (shareLayout) // layouts specifically shared across token-specific sessions
│ ├── activate // activate page
│ ├── components // shared by pages and layouts
│ ├── install // install page
│ ├── signin // signin page
│ └── styles // globally shared styles
├── assets // Static assets
├── bin // scripts ran at build step
├── config // adjustable settings and options
├── context // shared contexts used by different portions of the app
├── dictionaries // Language-specific translate files
├── docker // container configurations
├── hooks // Reusable hooks
├── i18n // Internationalization configuration
├── models // describes data models & shapes of API responses
├── public // meta assets like favicon
├── service // specifies shapes of API actions
├── test
├── types // 函数参数和返回值的描述
└── utils // 共享的实用函数
├── types // descriptions of function params and return values
└── utils // Shared utility functions
```
## 提交你的 PR

View File

@@ -82,7 +82,7 @@ Dify はバックエンドとフロントエンドから構成されています
まず`cd api/`でバックエンドのディレクトリに移動し、[Backend README](api/README.md)に従ってインストールします。
次に別のターミナルで、`cd web/`でフロントエンドのディレクトリに移動し、[Frontend README](web/README.md)に従ってインストールしてください。
よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/v/japanese/learn-more/faq/install-faq) を確認してください。
よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) を確認してください。
### 5. ブラウザで dify にアクセスする

View File

@@ -256,7 +256,3 @@ WORKFLOW_CALL_MAX_DEPTH=5
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1

View File

@@ -23,7 +23,6 @@ class SecurityConfig(BaseSettings):
default=24,
)
class AppExecutionConfig(BaseSettings):
"""
App Execution configs
@@ -436,13 +435,6 @@ class ImageFormatConfig(BaseSettings):
)
class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field(
description='the time of the celery scheduler, default to 1 day',
default=1,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@@ -470,6 +462,5 @@ class FeatureConfig(
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,
):
pass
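The hunk above removes `CeleryBeatConfig` from the `FeatureConfig` composition. The pattern at work is many small `BaseSettings` classes merged by multiple inheritance, with values read from the environment. A minimal sketch of that pattern under pydantic-settings (hypothetical field set):

```
# Sketch of composing BaseSettings classes via multiple inheritance,
# mirroring the FeatureConfig pattern (hypothetical example).
from pydantic import Field
from pydantic_settings import BaseSettings

class CeleryBeatConfig(BaseSettings):
    CELERY_BEAT_SCHEDULER_TIME: int = Field(
        description='interval of the celery beat scheduler, in days',
        default=1,
    )

class AppExecutionConfig(BaseSettings):
    APP_MAX_EXECUTION_TIME: int = Field(default=1200)

class FeatureConfig(AppExecutionConfig, CeleryBeatConfig):
    # Each parent contributes its fields; values come from the process
    # environment (or a .env file) at instantiation time.
    pass

config = FeatureConfig()  # e.g. export CELERY_BEAT_SCHEDULER_TIME=7
print(config.CELERY_BEAT_SCHEDULER_TIME, config.APP_MAX_EXECUTION_TIME)
```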

View File

@@ -79,7 +79,7 @@ class HostedAzureOpenAiConfig(BaseSettings):
default=False,
)
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
description='',
default=None,
)

View File

@@ -1,3 +1,4 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
@@ -7,32 +8,32 @@ class MyScaleConfig(BaseModel):
MyScale configs
"""
MYSCALE_HOST: str = Field(
MYSCALE_HOST: Optional[str] = Field(
description='MyScale host',
default='localhost',
default=None,
)
MYSCALE_PORT: PositiveInt = Field(
MYSCALE_PORT: Optional[PositiveInt] = Field(
description='MyScale port',
default=8123,
)
MYSCALE_USER: str = Field(
MYSCALE_USER: Optional[str] = Field(
description='MyScale user',
default='default',
default=None,
)
MYSCALE_PASSWORD: str = Field(
MYSCALE_PASSWORD: Optional[str] = Field(
description='MyScale password',
default='',
default=None,
)
MYSCALE_DATABASE: str = Field(
MYSCALE_DATABASE: Optional[str] = Field(
description='MyScale database name',
default='default',
default=None,
)
MYSCALE_FTS_PARAMS: str = Field(
MYSCALE_FTS_PARAMS: Optional[str] = Field(
description='MyScale fts index parameters',
default='',
default=None,
)
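This hunk swaps concrete defaults (`'localhost'`, `8123`, `'default'`) for `Optional` fields defaulting to `None`. A tiny sketch of what that changes for callers, using hypothetical minimal models:

```
# Sketch: concrete default vs. Optional-with-None for a settings field.
from typing import Optional
from pydantic import BaseModel, Field

class Before(BaseModel):
    MYSCALE_HOST: str = Field(default='localhost')

class After(BaseModel):
    MYSCALE_HOST: Optional[str] = Field(default=None)

print(Before().MYSCALE_HOST)  # 'localhost': always a usable string
print(After().MYSCALE_HOST)   # None: callers must handle the unset case
```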

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description='Dify version',
default='0.6.14',
default='0.6.13',
)
COMMIT_SHA: str = Field(

View File

@@ -15,7 +15,6 @@ from fields.app_fields import (
app_pagination_fields,
)
from libs.login import login_required
from services.app_dsl_service import AppDslService
from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
@@ -98,42 +97,8 @@ class AppImportApi(Resource):
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id,
data=args['data'],
args=args,
account=current_user
)
return app, 201
class AppImportFromUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check('apps')
def post(self):
"""Import app from url"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('url', type=str, required=True, nullable=False, location='json')
parser.add_argument('name', type=str, location='json')
parser.add_argument('description', type=str, location='json')
parser.add_argument('icon', type=str, location='json')
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
app = AppDslService.import_and_create_new_app_from_url(
tenant_id=current_user.current_tenant_id,
url=args['url'],
args=args,
account=current_user
)
app_service = AppService()
app = app_service.import_app(current_user.current_tenant_id, args['data'], args, current_user)
return app, 201
@@ -212,13 +177,9 @@ class AppCopyApi(Resource):
parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args()
data = AppDslService.export_dsl(app_model=app_model)
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id,
data=data,
args=args,
account=current_user
)
app_service = AppService()
data = app_service.export_app(app_model)
app = app_service.import_app(current_user.current_tenant_id, data, args, current_user)
return app, 201
@@ -234,8 +195,10 @@ class AppExportApi(Resource):
if not current_user.is_editor:
raise Forbidden()
app_service = AppService()
return {
"data": AppDslService.export_dsl(app_model=app_model)
"data": app_service.export_app(app_model)
}
@@ -359,7 +322,6 @@ class AppTraceApi(Resource):
api.add_resource(AppListApi, '/apps')
api.add_resource(AppImportApi, '/apps/import')
api.add_resource(AppImportFromUrlApi, '/apps/import/url')
api.add_resource(AppApi, '/apps/<uuid:app_id>')
api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')

View File

@@ -20,7 +20,6 @@ from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required
from models.model import App, AppMode
from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService
@@ -129,7 +128,8 @@ class DraftWorkflowImportApi(Resource):
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
workflow = AppDslService.import_and_overwrite_workflow(
workflow_service = WorkflowService()
workflow = workflow_service.import_draft_workflow(
app_model=app_model,
data=args['data'],
account=current_user

View File

@@ -545,15 +545,15 @@ class DatasetRetrievalSettingApi(Resource):
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
RetrievalMethod.SEMANTIC_SEARCH
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
]
}
case _:
@@ -569,15 +569,15 @@ class DatasetRetrievalSettingMockApi(Resource):
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
RetrievalMethod.SEMANTIC_SEARCH
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value,
RetrievalMethod.SEMANTIC_SEARCH,
RetrievalMethod.FULL_TEXT_SEARCH,
RetrievalMethod.HYBRID_SEARCH,
]
}
case _:
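Dropping `.value` here can still yield a JSON-serializable payload provided `RetrievalMethod` is a `str`-based enum. A quick sketch of that assumption:

```
# Sketch: a str-based Enum serializes as its plain string value, which is
# what makes returning members instead of `.value` safe (assumption:
# RetrievalMethod inherits from str).
import json
from enum import Enum

class RetrievalMethod(str, Enum):
    SEMANTIC_SEARCH = 'semantic_search'
    FULL_TEXT_SEARCH = 'full_text_search'
    HYBRID_SEARCH = 'hybrid_search'

print(json.dumps({'retrieval_method': [RetrievalMethod.SEMANTIC_SEARCH]}))
# {"retrieval_method": ["semantic_search"]}
```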

View File

@@ -349,7 +349,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document = self.get_document(dataset_id, document_id)
if document.indexing_status in ['completed', 'error']:
raise DocumentAlreadyFinishedError()
indexing_runner.calculate_tokens(document)
data_process_rule = document.dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()

View File

@@ -75,7 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
)
if last_id is not None:
last_segment = db.session.get(DocumentSegment, str(last_id))
last_segment = DocumentSegment.query.get(str(last_id))
if last_segment:
query = query.filter(
DocumentSegment.position > last_segment.position)
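This hunk reverts `db.session.get(...)` back to the legacy `Query.get(...)`. For reference, `Session.get()` is the form SQLAlchemy 2.0 standardizes on; a self-contained comparison:

```
# Sketch: primary-key lookup with Session.get(), the SQLAlchemy 2.0 API
# that Query.get() is the legacy spelling of (standalone toy example).
from sqlalchemy import Integer, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Segment(Base):
    __tablename__ = 'segments'
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    position: Mapped[int] = mapped_column(Integer, default=0)

engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Segment(id=1, position=5))
    session.commit()
    segment = session.get(Segment, 1)  # instead of Segment.query.get(1)
    print(segment.position)  # 5
```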

View File

@@ -78,12 +78,10 @@ class ChatTextApi(InstalledAppResource):
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('text', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id', None)
text = args.get('text', None)
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
@@ -97,8 +95,7 @@ class ChatTextApi(InstalledAppResource):
response = AudioService.transcript_tts(
app_model=app_model,
message_id=message_id,
voice=voice,
text=text
voice=voice
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
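These audio controllers share the same `reqparse` pattern, and the `args.get('message_id', None)` to `args.get('message_id')` change is behavior-preserving, since unsent optional arguments already come back as `None`. A standalone sketch of the pattern:

```
# Sketch of the flask_restful reqparse pattern used by these controllers
# (toy endpoint; field names are illustrative).
from flask import Flask
from flask_restful import Api, Resource, reqparse

app = Flask(__name__)
api = Api(app)

class TextApi(Resource):
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('message_id', type=str, required=False, location='json')
        parser.add_argument('voice', type=str, location='json')
        args = parser.parse_args()
        # Optional arguments default to None, so an explicit
        # args.get('message_id', None) is equivalent to args.get('message_id').
        return {'message_id': args.get('message_id'), 'voice': args.get('voice')}

api.add_resource(TextApi, '/text')
```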

View File

@@ -117,7 +117,7 @@ class MemberUpdateRoleApi(Resource):
if not TenantAccountRole.is_valid_role(new_role):
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
member = db.session.get(Account, str(member_id))
member = Account.query.get(str(member_id))
if not member:
abort(404)

View File

@@ -3,9 +3,8 @@ from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from flask import abort, request
from flask import abort, current_app, request
from configs import dify_config
from extensions.ext_database import db
from models.model import EndUser
@@ -13,12 +12,12 @@ from models.model import EndUser
def inner_api_only(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.INNER_API:
if not current_app.config['INNER_API']:
abort(404)
# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get('X-Inner-Api-Key')
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
abort(404)
return view(*args, **kwargs)
@@ -29,7 +28,7 @@ def inner_api_only(view):
def inner_api_user_auth(view):
@wraps(view)
def decorated(*args, **kwargs):
if not dify_config.INNER_API:
if not current_app.config['INNER_API']:
return view(*args, **kwargs)
# get header 'X-Inner-Api-Key'

View File

@@ -1,7 +1,7 @@
from flask import current_app
from flask_restful import Resource, fields, marshal_with
from configs import dify_config
from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
@@ -78,7 +78,7 @@ class AppParameterApi(Resource):
"transfer_methods": ["remote_url", "local_file"]
}}),
'system_parameters': {
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
}
}

View File

@@ -76,12 +76,10 @@ class TextApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('text', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id', None)
text = args.get('text', None)
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
@@ -89,15 +87,15 @@ class TextApi(Resource):
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
else:
try:
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
'voice')
except Exception:
voice = None
response = AudioService.transcript_tts(
app_model=app_model,
message_id=message_id,
end_user=end_user.external_user_id,
voice=voice,
text=text
voice=voice
)
return response

View File

@@ -1,6 +1,6 @@
import logging
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
@@ -21,43 +21,14 @@ from core.errors.error import (
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from services.app_generate_service import AppGenerateService
logger = logging.getLogger(__name__)
class WorkflowRunApi(Resource):
workflow_run_fields = {
'id': fields.String,
'workflow_id': fields.String,
'status': fields.String,
'inputs': fields.Raw,
'outputs': fields.Raw,
'error': fields.String,
'total_steps': fields.Integer,
'total_tokens': fields.Integer,
'created_at': fields.DateTime,
'finished_at': fields.DateTime,
'elapsed_time': fields.Float,
}
@validate_app_token
@marshal_with(workflow_run_fields)
def get(self, app_model: App, workflow_id: str):
"""
Get a workflow task running detail
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
return workflow_run
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
"""
@@ -117,5 +88,5 @@ class WorkflowTaskStopApi(Resource):
}
api.add_resource(WorkflowRunApi, '/workflows/run/<string:workflow_id>', '/workflows/run')
api.add_resource(WorkflowRunApi, '/workflows/run')
api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop')
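The removed GET endpoint relied on flask_restful's `marshal_with` to whitelist response fields. A compact sketch of that mechanism (trimmed field set, not the full WorkflowRun schema):

```
# Sketch of flask_restful marshal_with: the decorator filters whatever
# the view returns down to the declared fields (toy example).
from flask import Flask
from flask_restful import Api, Resource, fields, marshal_with

app = Flask(__name__)
api = Api(app)

run_fields = {
    'id': fields.String,
    'status': fields.String,
    'total_tokens': fields.Integer,
}

class WorkflowRunApi(Resource):
    @marshal_with(run_fields)
    def get(self, workflow_id):
        # 'internal' is silently dropped because it is not in run_fields.
        return {'id': workflow_id, 'status': 'succeeded',
                'total_tokens': 42, 'internal': 'dropped'}

api.add_resource(WorkflowRunApi, '/workflows/run/<string:workflow_id>')
```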

View File

@@ -1,6 +1,6 @@
from flask import current_app
from flask_restful import Resource
from configs import dify_config
from controllers.service_api import api
@@ -9,7 +9,7 @@ class IndexApi(Resource):
return {
"welcome": "Dify OpenAPI",
"api_version": "v1",
"server_version": dify_config.CURRENT_VERSION,
"server_version": current_app.config['CURRENT_VERSION']
}

View File

@@ -1,6 +1,6 @@
from flask import current_app
from flask_restful import fields, marshal_with
from configs import dify_config
from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
@@ -75,7 +75,7 @@ class AppParameterApi(WebApiResource):
"transfer_methods": ["remote_url", "local_file"]
}}),
'system_parameters': {
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
}
}

View File

@@ -74,12 +74,10 @@ class TextApi(WebApiResource):
parser = reqparse.RequestParser()
parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument('voice', type=str, location='json')
parser.add_argument('text', type=str, location='json')
parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args()
message_id = args.get('message_id', None)
text = args.get('text', None)
message_id = args.get('message_id')
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict):
@@ -96,8 +94,7 @@ class TextApi(WebApiResource):
app_model=app_model,
message_id=message_id,
end_user=end_user.external_user_id,
voice=voice,
text=text
voice=voice
)
return response

View File

@@ -1,8 +1,8 @@
from flask import current_app
from flask_restful import fields, marshal_with
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
@@ -84,7 +84,7 @@ class AppSiteInfo:
self.can_replace_logo = can_replace_logo
if can_replace_logo:
base_url = dify_config.FILES_URL
base_url = current_app.config.get('FILES_URL')
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
self.custom_config = {

View File

@@ -255,12 +255,6 @@ class AdvancedChatAppRunner(AppRunner):
)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(
QueueStopEvent(stopped_by=stopped_by),

View File

@@ -214,6 +214,61 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
def calculate_tokens(self, tenant_id: str, tokens: int, dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
"""
Estimate the indexing for the document.
"""
embedding_model_instance = None
if dataset_id:
dataset = Dataset.query.filter_by(
id=dataset_id
).first()
if not dataset:
raise ValueError('Dataset not found.')
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
if dataset.embedding_model_provider:
embedding_model_instance = self.model_manager.get_model_instance(
tenant_id=tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
else:
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
if indexing_technique == 'high_quality':
embedding_model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts = []
total_segments = 0
total_price = 0
currency = 'USD'
if embedding_model_instance:
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
total_price = '{:f}'.format(embedding_price_info.total_amount)
currency = embedding_price_info.currency
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": total_price,
"currency": currency,
"preview": preview_texts
}
def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
indexing_technique: str = 'economy') -> dict:
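The arithmetic at the heart of the new `calculate_tokens` method is tokens multiplied by the embedding model's unit price. A hedged sketch of that calculation with made-up pricing (the real numbers come from `get_price()` on the model instance):

```
# Sketch of the token-cost arithmetic (hypothetical unit price; the real
# value comes from the embedding model's get_price()).
from decimal import Decimal

def estimate_embedding_price(tokens: int,
                             unit_price: Decimal = Decimal('0.0001'),
                             per_tokens: int = 1000,
                             currency: str = 'USD') -> dict:
    total = (Decimal(tokens) / per_tokens) * unit_price
    return {
        'tokens': tokens,
        'total_price': '{:f}'.format(total),
        'currency': currency,
    }

print(estimate_embedding_price(25_000))
# {'tokens': 25000, 'total_price': '0.0025', 'currency': 'USD'}
```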

View File

@@ -64,7 +64,6 @@ User Input:
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
"The output must be an array in JSON format following the specified schema:\n"
"[\"question1\",\"question2\",\"question3\"]\n"
)

View File

@@ -103,7 +103,7 @@ class TokenBufferMemory:
if curr_message_tokens > max_token_limit:
pruned_memory = []
while curr_message_tokens > max_token_limit and len(prompt_messages)>1:
while curr_message_tokens > max_token_limit and prompt_messages:
pruned_memory.append(prompt_messages.pop(0))
curr_message_tokens = self.model_instance.get_llm_num_tokens(
prompt_messages
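The behavioral difference in this hunk: the old guard `len(prompt_messages) > 1` always kept at least one message, while `and prompt_messages` lets pruning empty the history entirely. A toy reproduction:

```
# Toy reproduction of the pruning-loop guard change (illustrative only).
def prune(messages, token_counts, max_token_limit):
    total = sum(token_counts)
    pruned = []
    while total > max_token_limit and messages:  # new-style guard
        pruned.append(messages.pop(0))
        total -= token_counts.pop(0)
    return messages, pruned

kept, dropped = prune(['m1', 'm2', 'm3'], [50, 50, 50], 40)
print(kept, dropped)  # [] ['m1', 'm2', 'm3']: the history can now empty out
```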

View File

@@ -27,9 +27,9 @@ parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
default: 8192
default: 4096
min: 1
max: 8192
max: 4096
- name: response_format
use_template: response_format
pricing:

View File

@@ -113,11 +113,6 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
if system:
extra_model_kwargs['system'] = system
# Add the new header for claude-3-5-sonnet-20240620 model
extra_headers = {}
if model == "claude-3-5-sonnet-20240620":
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
if tools:
extra_model_kwargs['tools'] = [
self._transform_tool_prompt(tool) for tool in tools
@@ -126,7 +121,6 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
model=model,
messages=prompt_message_dicts,
stream=stream,
extra_headers=extra_headers,
**model_parameters,
**extra_model_kwargs
)
@@ -136,7 +130,6 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
model=model,
messages=prompt_message_dicts,
stream=stream,
extra_headers=extra_headers,
**model_parameters,
**extra_model_kwargs
)
@@ -145,7 +138,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
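The removed lines passed a per-request beta header through the Anthropic SDK for claude-3-5-sonnet-20240620. A minimal sketch of that call shape (requires a real ANTHROPIC_API_KEY to run; the header value is the one shown in the diff):

```
# Sketch: forwarding a beta header per request via the Anthropic SDK.
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the env

response = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=8192,
    # Beta header that raised the output cap at the time:
    extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"},
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.content[0].text)
```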

View File

@@ -71,9 +71,6 @@ model_credential_schema:
- label:
en_US: '2024-02-01'
value: '2024-02-01'
- label:
en_US: '2024-06-01'
value: '2024-06-01'
placeholder:
zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here

View File

@@ -66,10 +66,6 @@ provider_credential_schema:
label:
en_US: Europe (Frankfurt)
zh_Hans: 欧洲 (法兰克福)
- value: eu-west-2
label:
en_US: Eu west London (London)
zh_Hans: 欧洲西部 (伦敦)
- value: us-gov-west-1
label:
en_US: AWS GovCloud (US-West)

View File

@@ -48,28 +48,6 @@ logger = logging.getLogger(__name__)
class BedrockLargeLanguageModel(LargeLanguageModel):
# please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
# TODO There is invoke issue: context limit on Cohere Model, will add them after fixed.
CONVERSE_API_ENABLED_MODEL_INFO=[
{'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False},
{'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False},
{'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True},
{'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False},
{'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
{'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
{'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
{'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
{'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
]
@staticmethod
def _find_model_info(model_id):
for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO:
if model_id.startswith(model['prefix']):
return model
logger.info(f"current model id: {model_id} did not support by Converse API")
return None
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@@ -88,12 +66,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
model_info= BedrockLargeLanguageModel._find_model_info(model)
if model_info:
model_info['model'] = model
# invoke models via boto3 converse API
return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
# TODO: consolidate different invocation methods for models based on base model capabilities
# invoke anthropic models via boto3 client
if "anthropic" in model:
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
# invoke Cohere models via boto3 client
if "cohere.command-r" in model:
return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
@@ -175,12 +151,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
"""
Invoke large language model with converse API
Invoke Anthropic large language model
:param model_info: model information
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
@@ -197,24 +173,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
parameters = {
'modelId': model_info['model'],
'modelId': model,
'messages': prompt_message_dicts,
'inferenceConfig': inference_config,
'additionalModelRequestFields': additional_model_fields,
}
if model_info['support_system_prompts'] and system and len(system) > 0:
if system and len(system) > 0:
parameters['system'] = system
if model_info['support_tool_use'] and tools:
if tools:
parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
if stream:
response = bedrock_client.converse_stream(**parameters)
return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages)
return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
else:
response = bedrock_client.converse(**parameters)
return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages)
return self._handle_converse_response(model, credentials, response, prompt_messages)
def _handle_converse_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage]) -> LLMResult:
@@ -227,30 +203,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param prompt_messages: prompt messages
:return: full response chunk generator result
"""
response_content = response['output']['message']['content']
# transform assistant message to prompt message
if response['stopReason'] == 'tool_use':
tool_calls = []
text, tool_use = self._extract_tool_use(response_content)
tool_call = AssistantPromptMessage.ToolCall(
id=tool_use['toolUseId'],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_use['name'],
arguments=json.dumps(tool_use['input'])
)
)
tool_calls.append(tool_call)
assistant_prompt_message = AssistantPromptMessage(
content=text,
tool_calls=tool_calls
)
else:
assistant_prompt_message = AssistantPromptMessage(
content=response_content[0]['text']
)
assistant_prompt_message = AssistantPromptMessage(
content=response['output']['message']['content'][0]['text']
)
# calculate num tokens
if response['usage']:
@@ -273,18 +229,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
)
return result
def _extract_tool_use(self, content:dict)-> tuple[str, dict]:
tool_use = {}
text = ''
for item in content:
if 'toolUse' in item:
tool_use = item['toolUse']
elif 'text' in item:
text = item['text']
else:
raise ValueError(f"Got unknown item: {item}")
return text, tool_use
def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
prompt_messages: list[PromptMessage], ) -> Generator:
"""
@@ -396,12 +340,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
"""
system = []
prompt_message_dicts = []
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
message.content=message.content.strip()
system.append({"text": message.content})
else:
prompt_message_dicts = []
for message in prompt_messages:
if not isinstance(message, SystemPromptMessage):
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
return system, prompt_message_dicts
@@ -502,6 +448,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
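For context on what the removed `CONVERSE_API_ENABLED_MODEL_INFO` path did: it routed supported models through Bedrock's unified Converse API instead of per-vendor invocation code. A minimal boto3 sketch of that API (needs AWS credentials and model access; the model id is an example):

```
# Sketch: calling Bedrock's unified Converse API via boto3.
import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")

response = client.converse(
    modelId="anthropic.claude-3-sonnet-20240229-v1:0",  # example model id
    system=[{"text": "You are a terse assistant."}],
    messages=[{"role": "user", "content": [{"text": "Hello"}]}],
    inferenceConfig={"maxTokens": 256, "temperature": 0.5},
)
print(response["output"]["message"]["content"][0]["text"])
```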

View File

@@ -2,9 +2,6 @@ model: mistral.mistral-large-2402-v1:0
label:
en_US: Mistral Large
model_type: llm
features:
- tool-call
- agent-thought
model_properties:
mode: completion
context_size: 32000

View File

@@ -2,8 +2,6 @@ model: mistral.mistral-small-2402-v1:0
label:
en_US: Mistral Small
model_type: llm
features:
- tool-call
model_properties:
mode: completion
context_size: 32000

View File

@@ -7,7 +7,7 @@ features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature

View File

@@ -7,7 +7,7 @@ features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature

View File

@@ -1,8 +1,6 @@
- gpt-4
- gpt-4o
- gpt-4o-2024-05-13
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- gpt-4-turbo
- gpt-4-turbo-2024-04-09
- gpt-4-turbo-preview

View File

@@ -1,44 +0,0 @@
model: gpt-4o-mini-2024-07-18
label:
zh_Hans: gpt-4o-mini-2024-07-18
en_US: gpt-4o-mini-2024-07-18
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.15'
output: '0.60'
unit: '0.000001'
currency: USD

View File

@@ -1,44 +0,0 @@
model: gpt-4o-mini
label:
zh_Hans: gpt-4o-mini
en_US: gpt-4o-mini
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.15'
output: '0.60'
unit: '0.000001'
currency: USD

View File

@@ -616,34 +616,30 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
function_calling_type = credentials.get('function_calling_type', 'no_call')
if function_calling_type == 'tool_call':
message_dict["tool_calls"] = [tool_call.dict() for tool_call in
message.tool_calls]
elif function_calling_type == 'function_call':
function_call = message.tool_calls[0]
message_dict["function_call"] = {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
}
# message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
# in
# message.tool_calls]
function_call = message.tool_calls[0]
message_dict["function_call"] = {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
function_calling_type = credentials.get('function_calling_type', 'no_call')
if function_calling_type == 'tool_call':
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id
}
elif function_calling_type == 'function_call':
message_dict = {
"role": "function",
"content": message.content,
"name": message.tool_call_id
}
# message_dict = {
# "role": "tool",
# "content": message.content,
# "tool_call_id": message.tool_call_id
# }
message_dict = {
"role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function",
"content": message.content,
"name": message.tool_call_id
}
else:
raise ValueError(f"Got unknown type {message}")
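The left-hand side of this hunk branches on a `function_calling_type` credential to emit either the modern `tool_calls`/`tool` message shape or the legacy `function_call`/`function` one. A standalone sketch of the two wire formats (static payloads, no network):

```
# Sketch: the two OpenAI-compatible assistant-message shapes the branch
# above selects between (illustrative payloads).
def assistant_message(content, tool_calls, function_calling_type='no_call'):
    message = {"role": "assistant", "content": content}
    if not tool_calls:
        return message
    if function_calling_type == 'tool_call':        # modern tools API
        message["tool_calls"] = tool_calls
    elif function_calling_type == 'function_call':  # legacy functions API
        call = tool_calls[0]["function"]
        message["function_call"] = {"name": call["name"],
                                    "arguments": call["arguments"]}
    return message

calls = [{"id": "call_1", "type": "function",
          "function": {"name": "get_weather", "arguments": '{"city": "SF"}'}}]
print(assistant_message(None, calls, 'tool_call'))
print(assistant_message(None, calls, 'function_call'))
```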

View File

@@ -1,5 +1,4 @@
- openai/gpt-4o
- openai/gpt-4o-mini
- openai/gpt-4
- openai/gpt-4-32k
- openai/gpt-3.5-turbo

View File

@@ -1,43 +0,0 @@
model: openai/gpt-4o-mini
label:
en_US: gpt-4o-mini
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: "0.15"
output: "0.60"
unit: "0.000001"
currency: USD

Binary file not shown (image, 9.2 KiB before deletion).

Binary file not shown (image, 9.5 KiB before deletion).

View File

@@ -1,238 +0,0 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union
import boto3
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
class SageMakerLargeLanguageModel(LargeLanguageModel):
"""
Model class for Cohere large language model.
"""
sagemaker_client: Any = None
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# get model mode
model_mode = self.get_model_mode(model, credentials)
if not self.sagemaker_client:
access_key = credentials.get('access_key')
secret_key = credentials.get('secret_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=sagemaker_endpoint,
Body=json.dumps(
{
"inputs": prompt_messages[0].content,
"parameters": { "stop" : stop},
"history" : []
}
),
ContentType="application/json",
)
assistant_text = response_model['Body'].read().decode('utf8')
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
usage = self._calc_response_usage(model, credentials, 0, 0)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
)
return response
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
# get model mode
model_mode = self.get_model_mode(model)
try:
return 0
except Exception as e:
raise self._transform_invoke_error(e)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
# get model mode
model_mode = self.get_model_mode(model)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature',
type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
),
),
ParameterRule(
name='top_p',
type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P',
en_US='Top P'
)
),
ParameterRule(
name='max_tokens',
type=ParameterType.INT,
use_template='max_tokens',
min=1,
max=credentials.get('context_length', 2048),
default=512,
label=I18nObject(
zh_Hans='最大生成长度',
en_US='Max Tokens'
)
)
]
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type == LLMMode.CHAT:
print(f"completion_type : {LLMMode.CHAT.value}")
if completion_type == LLMMode.COMPLETION:
print(f"completion_type : {LLMMode.COMPLETION.value}")
features = []
support_function_call = credentials.get('support_function_call', False)
if support_function_call:
features.append(ModelFeature.TOOL_CALL)
support_vision = credentials.get('support_vision', False)
if support_vision:
features.append(ModelFeature.VISION)
context_length = credentials.get('context_length', 2048)
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=features,
model_properties={
ModelPropertyKey.MODE: completion_type,
ModelPropertyKey.CONTEXT_SIZE: context_length
},
parameter_rules=rules
)
return entity

View File

@@ -1,190 +0,0 @@
import json
import logging
from typing import Any, Optional
import boto3
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
logger = logging.getLogger(__name__)
class SageMakerRerankModel(RerankModel):
"""
Model class for Cohere rerank model.
"""
sagemaker_client: Any = None
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str):
inputs = [query_input]*len(docs)
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=rerank_endpoint,
Body=json.dumps(
{
"inputs": inputs,
"docs": docs
}
),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
scores = json_obj['scores']
return scores if isinstance(scores, list) else [scores]
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
line = 0
try:
if len(docs) == 0:
return RerankResult(
model=model,
docs=docs
)
line = 1
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
line = 2
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
candidate_docs = []
scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)
for idx in range(len(scores)):
candidate_docs.append({"content" : docs[idx], "score": scores[idx]})
sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
line = 3
rerank_documents = []
for idx, result in enumerate(candidate_docs):
rerank_document = RerankDocument(
index=idx,
text=result.get('content'),
score=result.get('score', -100.0)
)
if score_threshold is not None:
if rerank_document.score >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document)
return RerankResult(
model=model,
docs=rerank_documents
)
except Exception as e:
logger.exception(f'Exception {e}, line : {line}')
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={ },
parameter_rules=[]
)
return entity

View File

@@ -1,17 +0,0 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class SageMakerProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
pass

View File

@@ -1,125 +0,0 @@
provider: sagemaker
label:
zh_Hans: Sagemaker
en_US: Sagemaker
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
description:
en_US: Customized model on Sagemaker
zh_Hans: Sagemaker上的私有化部署的模型
background: "#ECE9E3"
help:
title:
en_US: How to deploy customized model on Sagemaker
zh_Hans: 如何在Sagemaker上的私有化部署的模型
url:
en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint
zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9
supported_model_types:
- llm
- text-embedding
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: sagemaker_endpoint
label:
en_US: sagemaker endpoint
type: text-input
required: true
placeholder:
zh_Hans: 请输出你的Sagemaker推理端点
en_US: Enter your Sagemaker Inference endpoint
- variable: aws_access_key_id
required: false
label:
en_US: Access Key (If not provided, credentials are obtained from the running environment.)
zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。)
type: secret-input
placeholder:
en_US: Enter your Access Key
zh_Hans: 在此输入您的 Access Key
- variable: aws_secret_access_key
required: false
label:
en_US: Secret Access Key
zh_Hans: Secret Access Key
type: secret-input
placeholder:
en_US: Enter your Secret Access Key
zh_Hans: 在此输入您的 Secret Access Key
- variable: aws_region
required: false
label:
en_US: AWS Region
zh_Hans: AWS 地区
type: select
default: us-east-1
options:
- value: us-east-1
label:
en_US: US East (N. Virginia)
zh_Hans: 美国东部 (弗吉尼亚北部)
- value: us-west-2
label:
en_US: US West (Oregon)
zh_Hans: 美国西部 (俄勒冈州)
- value: ap-southeast-1
label:
en_US: Asia Pacific (Singapore)
zh_Hans: 亚太地区 (新加坡)
- value: ap-northeast-1
label:
en_US: Asia Pacific (Tokyo)
zh_Hans: 亚太地区 (东京)
- value: eu-central-1
label:
en_US: Europe (Frankfurt)
zh_Hans: 欧洲 (法兰克福)
- value: us-gov-west-1
label:
en_US: AWS GovCloud (US-West)
zh_Hans: AWS GovCloud (US-West)
- value: ap-southeast-2
label:
en_US: Asia Pacific (Sydney)
zh_Hans: 亚太地区 (悉尼)
- value: cn-north-1
label:
en_US: AWS Beijing (cn-north-1)
zh_Hans: 中国北京 (cn-north-1)
- value: cn-northwest-1
label:
en_US: AWS Ningxia (cn-northwest-1)
zh_Hans: 中国宁夏 (cn-northwest-1)

View File

@@ -1,214 +0,0 @@
import itertools
import json
import logging
import time
from typing import Any, Optional
import boto3
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
BATCH_SIZE = 20
CONTEXT_SIZE=8192
logger = logging.getLogger(__name__)
def batch_generator(generator, batch_size):
while True:
batch = list(itertools.islice(generator, batch_size))
if not batch:
break
yield batch
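# Editor's note (illustrative, not part of the original file): batch_generator
# slices any iterator into fixed-size lists, e.g.
#   list(batch_generator(iter(range(5)), batch_size=2)) == [[0, 1], [2, 3], [4]]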
class SageMakerEmbeddingModel(TextEmbeddingModel):
"""
Model class for a SageMaker-hosted text embedding model.
"""
sagemaker_client: Any = None
def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]):
response_model = sm_client.invoke_endpoint(
EndpointName=endpoint_name,
Body=json.dumps(
{
"inputs": content_list,
"parameters": {},
"is_query" : False,
"instruction" : ''
}
),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
embeddings = json_obj['embeddings']
return embeddings
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
# get model properties
try:
line = 1
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
line = 2
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
line = 3
truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ]
batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
all_embeddings = []
line = 4
for batch in batches:
embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
all_embeddings.extend(embeddings)
line = 5
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=0  # not a SaaS API; token usage is not metered for SageMaker endpoints
)
line = 6
return TextEmbeddingResult(
embeddings=all_embeddings,
usage=usage,
model=model
)
except Exception as e:
logger.exception(f'Exception {e}, line : {line}')
raise  # re-raise so the caller sees the failure instead of a silent None
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return 0
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
print("validate_credentials ok....")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
},
parameter_rules=[]
)
return entity
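To tie the class back to the sagemaker YAML schema above, a credentials dict for this embedding model would look roughly like the following (a sketch using the schema's field names; the endpoint name is a placeholder, and the AWS keys may be omitted to fall back to the running environment):

    credentials = {
        'sagemaker_endpoint': 'my-embedding-endpoint',  # placeholder
        'aws_region': 'us-east-1',
        # 'aws_access_key_id': '...',      # optional, per the YAML schema
        # 'aws_secret_access_key': '...',  # optional
    }
    # A model instance would then be invoked along the lines of:
    # result = embedding_model._invoke(model='my-embedding-endpoint',
    #                                  credentials=credentials,
    #                                  texts=['hello world'])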

Binary file not shown (removed image, 9.0 KiB)

Binary file not shown (removed image, 1.9 KiB)

View File

@@ -1,6 +0,0 @@
- step-1-8k
- step-1-32k
- step-1-128k
- step-1-256k
- step-1v-8k
- step-1v-32k

View File

@@ -1,328 +0,0 @@
import json
from collections.abc import Generator
from typing import Optional, Union, cast
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
self._add_function_call(model, credentials)
user = user[:32] if user else None
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get('function_calling_type') == 'tool_call'
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name='temperature',
use_template='temperature',
label=I18nObject(en_US='Temperature', zh_Hans='温度'),
type=ParameterType.FLOAT,
),
ParameterRule(
name='max_tokens',
use_template='max_tokens',
default=512,
min=1,
max=int(credentials.get('max_tokens', 1024)),
label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
type=ParameterType.INT,
),
ParameterRule(
name='top_p',
use_template='top_p',
label=I18nObject(en_US='Top P', zh_Hans='Top P'),
type=ParameterType.FLOAT,
),
]
)
def _add_custom_parameters(self, credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.stepfun.com/v1'
def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials)
if model_schema and {
ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
}.intersection(model_schema.features or []):
credentials['function_calling_type'] = 'tool_call'
def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict:
"""
Convert PromptMessage to dict for OpenAI API format
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = []
for function_call in message.tool_calls:
message_dict["tool_calls"].append({
"id": function_call.id,
"type": function_call.type,
"function": {
"name": function_call.function.name,
"arguments": function_call.function.arguments
}
})
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call["id"] if response_tool_call.get("id") else "",
type=response_tool_call["type"] if response_tool_call.get("type") else "",
function=function
)
tool_calls.append(tool_call)
return tool_calls
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: model credentials
:param response: streamed response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ''
chunk_index = 0
def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
-> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=message,
finish_reason=finish_reason,
usage=usage
)
)
tools_calls: list[AssistantPromptMessage.ToolCall] = []
finish_reason = "Unknown"
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id='',
type='',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
if chunk:
# ignore sse comments
if chunk.startswith(':'):
continue
decoded_chunk = chunk.strip().removeprefix('data:').lstrip()  # remove the SSE prefix; lstrip('data: ') would strip characters, not the prefix
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
break
if not chunk_json or len(chunk_json['choices']) == 0:
continue
choice = chunk_json['choices'][0]
finish_reason = chunk_json['choices'][0].get('finish_reason')
chunk_index += 1
if 'delta' in choice:
delta = choice['delta']
delta_content = delta.get('content')
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content,
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta_content
elif 'text' in choice:
choice_text = choice.get('text', '')
if choice_text == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
full_assistant_content += choice_text
else:
continue
# check payload indicator for completion
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
)
)
chunk_index += 1
if tools_calls:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(
tool_calls=tools_calls,
content=""
),
)
)
yield create_final_llm_result_chunk(
index=chunk_index,
message=AssistantPromptMessage(content=""),
finish_reason=finish_reason
)
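For orientation, `_handle_generate_stream_response` above parses OpenAI-compatible server-sent events. A representative stream looks like this (illustrative payloads, not captured from the Stepfun API):

    data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}
    data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]}
    data: [DONE]

Each line is stripped of its `data:` prefix and JSON-decoded; `choices[0].delta.content` is accumulated into the assistant message, `delta.tool_calls` fragments are merged by function name, and the non-JSON `[DONE]` sentinel triggers the final usage-bearing chunk.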

View File

@@ -1,25 +0,0 @@
model: step-1-128k
label:
zh_Hans: step-1-128k
en_US: step-1-128k
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 128000
pricing:
input: '0.04'
output: '0.20'
unit: '0.001'
currency: RMB

View File

@@ -1,25 +0,0 @@
model: step-1-256k
label:
zh_Hans: step-1-256k
en_US: step-1-256k
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 256000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 256000
pricing:
input: '0.095'
output: '0.300'
unit: '0.001'
currency: RMB

View File

@@ -1,28 +0,0 @@
model: step-1-32k
label:
zh_Hans: step-1-32k
en_US: step-1-32k
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
pricing:
input: '0.015'
output: '0.070'
unit: '0.001'
currency: RMB

View File

@@ -1,28 +0,0 @@
model: step-1-8k
label:
zh_Hans: step-1-8k
en_US: step-1-8k
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 8000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8000
pricing:
input: '0.005'
output: '0.020'
unit: '0.001'
currency: RMB

View File

@@ -1,25 +0,0 @@
model: step-1v-32k
label:
zh_Hans: step-1v-32k
en_US: step-1v-32k
model_type: llm
features:
- vision
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
pricing:
input: '0.015'
output: '0.070'
unit: '0.001'
currency: RMB

View File

@@ -1,25 +0,0 @@
model: step-1v-8k
label:
zh_Hans: step-1v-8k
en_US: step-1v-8k
model_type: llm
features:
- vision
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.005'
output: '0.020'
unit: '0.001'
currency: RMB

View File

@@ -1,30 +0,0 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class StepfunProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validation fails, raise an exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='step-1-8k',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex

View File

@@ -1,81 +0,0 @@
provider: stepfun
label:
zh_Hans: 阶跃星辰
en_US: Stepfun
description:
en_US: Models provided by stepfun, such as step-1-8k, step-1-32k, step-1v-8k, step-1v-32k, step-1-128k and step-1-256k.
zh_Hans: 阶跃星辰提供的模型,例如 step-1-8k、step-1-32k、step-1v-8k、step-1v-32k、step-1-128k 和 step-1-256k。
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
background: "#FFFFFF"
help:
title:
en_US: Get your API Key from stepfun
zh_Hans: 从 stepfun 获取 API Key
url:
en_US: https://platform.stepfun.com/interface-key
supported_model_types:
- llm
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '8192'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
default: '8192'
type: text-input
- variable: function_calling_type
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: no_call
label:
en_US: Not supported
zh_Hans: 不支持
- value: tool_call
label:
en_US: Tool Call
zh_Hans: Tool Call

View File

@@ -29,7 +29,7 @@ model_credential_schema:
label:
zh_Hans: 服务器URL
en_US: Server url
type: text-input
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000

View File

@@ -35,4 +35,3 @@ parameter_rules:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model from performing external searches.
required: false
deprecated: true

View File

@@ -1,4 +1,4 @@
model: ernie-4.0-8k-latest
model: ernie-4.0-8k-Latest
label:
en_US: Ernie-4.0-8K-Latest
model_type: llm

View File

@@ -1,40 +0,0 @@
model: ernie-4.0-turbo-8k-preview
label:
en_US: Ernie-4.0-turbo-8k-preview
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.1
max: 1.0
default: 0.8
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 2
max: 2048
- name: presence_penalty
use_template: presence_penalty
default: 1.0
min: 1.0
max: 2.0
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
use_template: response_format
- name: disable_search
label:
zh_Hans: 禁用搜索
en_US: Disable Search
type: boolean
help:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model from performing external searches.
required: false

View File

@@ -1,40 +0,0 @@
model: ernie-4.0-turbo-8k
label:
en_US: Ernie-4.0-turbo-8K
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.1
max: 1.0
default: 0.8
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 2
max: 2048
- name: presence_penalty
use_template: presence_penalty
default: 1.0
min: 1.0
max: 2.0
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
use_template: response_format
- name: disable_search
label:
zh_Hans: 禁用搜索
en_US: Disable Search
type: boolean
help:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model from performing external searches.
required: false

View File

@@ -28,4 +28,3 @@ parameter_rules:
default: 1.0
min: 1.0
max: 2.0
deprecated: true

View File

@@ -1,30 +0,0 @@
model: ernie-character-8k-0321
label:
en_US: ERNIE-Character-8K
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.1
max: 1.0
default: 0.95
- name: top_p
use_template: top_p
min: 0
max: 1.0
default: 0.7
- name: max_tokens
use_template: max_tokens
default: 1024
min: 2
max: 1024
- name: presence_penalty
use_template: presence_penalty
default: 1.0
min: 1.0
max: 2.0

View File

@@ -28,4 +28,3 @@ parameter_rules:
default: 1.0
min: 1.0
max: 2.0
deprecated: true

View File

@@ -28,4 +28,3 @@ parameter_rules:
default: 1.0
min: 1.0
max: 2.0
deprecated: true

View File

@@ -97,7 +97,6 @@ class BaiduAccessToken:
baidu_access_tokens_lock.release()
return token
class ErnieMessage:
class Role(Enum):
USER = 'user'
@@ -138,10 +137,7 @@ class ErnieBotModel:
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-tutbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-tutbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
}
function_calling_supports = [
@@ -152,9 +148,7 @@ class ErnieBotModel:
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview'
'ernie-4.0-8k'
]
api_key: str = ''

View File

@@ -453,11 +453,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
api_key = credentials.get('api_key') or "abc"
client = OpenAI(
base_url=f'{credentials["server_url"]}/v1',
api_key=api_key,
api_key='abc',
max_retries=3,
timeout=60,
)

View File

@@ -44,23 +44,15 @@ class XinferenceRerankModel(RerankModel):
docs=[]
)
server_url = credentials['server_url']
model_uid = credentials['model_uid']
api_key = credentials.get('api_key')
if server_url.endswith('/'):
server_url = server_url[:-1]
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
try:
handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
response = handle.rerank(
documents=docs,
query=query,
top_n=top_n,
)
except RuntimeError as e:
raise InvokeServerUnavailableError(str(e))
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
response = handle.rerank(
documents=docs,
query=query,
top_n=top_n,
)
rerank_documents = []
for idx, result in enumerate(response['results']):
@@ -110,7 +102,7 @@ class XinferenceRerankModel(RerankModel):
if not isinstance(xinference_client, RESTfulRerankModelHandle):
raise InvokeBadRequestError(
'please check model type, the model you want to invoke is not a rerank model')
self.invoke(
model=model,
credentials=credentials,

View File

@@ -99,9 +99,9 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
}
def _speech2text_invoke(
self,
model: str,
credentials: dict,
self,
model: str,
credentials: dict,
file: IO[bytes],
language: Optional[str] = None,
prompt: Optional[str] = None,
@@ -121,24 +121,17 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
:param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
:return: text for given audio file
"""
server_url = credentials['server_url']
model_uid = credentials['model_uid']
api_key = credentials.get('api_key')
if server_url.endswith('/'):
server_url = server_url[:-1]
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
try:
handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers)
response = handle.transcriptions(
audio=file,
language=language,
prompt=prompt,
response_format=response_format,
temperature=temperature
)
except RuntimeError as e:
raise InvokeServerUnavailableError(str(e))
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
response = handle.transcriptions(
audio=file,
language=language,
prompt=prompt,
response_format=response_format,
temperature=temperature
)
return response["text"]

View File

@@ -43,17 +43,16 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
server_url = credentials['server_url']
model_uid = credentials['model_uid']
api_key = credentials.get('api_key')
if server_url.endswith('/'):
server_url = server_url[:-1]
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
try:
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers)
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
embeddings = handle.create_embedding(input=texts)
except RuntimeError as e:
raise InvokeServerUnavailableError(str(e))
raise InvokeServerUnavailableError(e)
"""
for convenience, the response json is like:
class Embedding(TypedDict):
@@ -107,7 +106,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
try:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
@@ -118,7 +117,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
server_url = server_url[:-1]
client = Client(base_url=server_url)
try:
handle = client.get_model(model_uid=model_uid)
except RuntimeError as e:
@@ -152,7 +151,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
KeyError
]
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
@@ -187,7 +186,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(

View File

@@ -32,7 +32,7 @@ model_credential_schema:
label:
zh_Hans: 服务器URL
en_US: Server url
type: text-input
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 http://192.168.1.100:9997
@@ -46,12 +46,3 @@ model_credential_schema:
placeholder:
zh_Hans: 在此输入您的Model UID
en_US: Enter the model uid
- variable: api_key
label:
zh_Hans: API密钥
en_US: API key
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的API密钥
en_US: Enter the api key

View File

@@ -20,7 +20,7 @@ class ZhipuaiProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='glm-4',
model='chatglm_turbo',
credentials=credentials
)
except CredentialsValidateFailedError as ex:

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
import httpx
from ..core._base_api import BaseAPI
from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven
from ..core._base_type import NOT_GIVEN, Headers, NotGiven
from ..core._http_client import make_user_request_input
from ..types.image import ImagesResponded
@@ -28,9 +28,7 @@ class Images(BaseAPI):
size: Optional[str] | NotGiven = NOT_GIVEN,
style: Optional[str] | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ImagesResponded:
@@ -48,12 +46,9 @@ class Images(BaseAPI):
"size": size,
"style": style,
"user": user,
"request_id": request_id,
},
options=make_user_request_input(
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout
extra_headers=extra_headers, timeout=timeout
),
cast_type=_cast_type,
enable_stream=False,

View File

@@ -11,7 +11,7 @@ from tenacity import retry
from tenacity.stop import stop_after_attempt
from . import _errors
from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
from ._base_type import NOT_GIVEN, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import make_httpx_files
from ._request_opt import ClientRequestParam, UserRequestInput
@@ -358,7 +358,6 @@ def make_user_request_input(
max_retries: int | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
extra_headers: Headers = None,
extra_body: Body | None = None,
query: Query | None = None,
) -> UserRequestInput:
options: UserRequestInput = {}
@@ -371,7 +370,5 @@ def make_user_request_input(
options['timeout'] = timeout
if query is not None:
options["params"] = query
if extra_body is not None:
options["extra_json"] = cast(AnyMapping, extra_body)
return options

View File

@@ -1,6 +1,7 @@
from typing import Any
from configs import dify_config
from flask import current_app
from core.rag.datasource.keyword.jieba.jieba import Jieba
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
@@ -13,8 +14,8 @@ class Keyword:
self._keyword_processor = self._init_keyword()
def _init_keyword(self) -> BaseKeyword:
config = dify_config
keyword_type = config.KEYWORD_STORE
config = current_app.config
keyword_type = config.get('KEYWORD_STORE')
if not keyword_type:
raise ValueError("Keyword store must be specified.")
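The swap between `dify_config` and `current_app.config` seen here (and repeated in the vector-store factories below) is not purely cosmetic: the two styles fail differently on missing settings. An illustrative contrast (assuming `dify_config` is a typed settings object, which this diff suggests but does not show):

    value = current_app.config.get('KEYWORD_STORE')  # returns None when unset
    value = dify_config.KEYWORD_STORE                # attribute access; an undefined
                                                     # field raises instead of yielding None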

View File

@@ -11,7 +11,7 @@ from extensions.ext_database import db
from models.dataset import Dataset
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
@@ -86,7 +86,7 @@ class RetrievalService:
exception_message = ';\n'.join(exceptions)
raise Exception(exception_message)
if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
if retrival_method == RetrievalMethod.HYBRID_SEARCH:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents = data_post_processor.invoke(
query=query,
@@ -142,7 +142,7 @@ class RetrievalService:
)
if documents:
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents.extend(data_post_processor.invoke(
query=query,
@@ -174,7 +174,7 @@ class RetrievalService:
top_k=top_k
)
if documents:
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents.extend(data_post_processor.invoke(
query=query,
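A side note on the `.value` changes in this file: comparing a plain string to an Enum member only succeeds if the enum subclasses `str`. Since the enum definition is not shown in this diff, here is a hedged illustration of the pitfall (the class shape below is assumed):

    from enum import Enum

    class RetrievalMethod(Enum):                     # assumed shape
        SEMANTIC_SEARCH = 'semantic_search'

    'semantic_search' == RetrievalMethod.SEMANTIC_SEARCH         # False for a plain Enum
    'semantic_search' == RetrievalMethod.SEMANTIC_SEARCH.value   # True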

View File

@@ -7,8 +7,8 @@ _import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from flask import current_app
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -36,7 +36,7 @@ class AnalyticdbConfig(BaseModel):
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVector(BaseVector):
_instance = None
_init = False
@@ -45,7 +45,7 @@ class AnalyticdbVector(BaseVector):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, collection_name: str, config: AnalyticdbConfig):
# collection_name must be updated every time
self._collection_name = collection_name.lower()
@@ -105,7 +105,7 @@ class AnalyticdbVector(BaseVector):
raise ValueError(
f"failed to create namespace {self.config.namespace}: {e}"
)
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
@@ -149,7 +149,7 @@ class AnalyticdbVector(BaseVector):
def get_type(self) -> str:
return VectorType.ANALYTICDB
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection_if_not_exists(dimension)
@@ -199,7 +199,7 @@ class AnalyticdbVector(BaseVector):
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
@@ -260,7 +260,7 @@ class AnalyticdbVector(BaseVector):
)
documents.append(doc)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
@@ -291,7 +291,7 @@ class AnalyticdbVector(BaseVector):
)
documents.append(doc)
return documents
def delete(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
@@ -316,18 +316,17 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
)
# TODO handle optional params
config = current_app.config
return AnalyticdbVector(
collection_name,
AnalyticdbConfig(
access_key_id=dify_config.ANALYTICDB_KEY_ID,
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
region_id=dify_config.ANALYTICDB_REGION_ID,
instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
namespace=dify_config.ANALYTICDB_NAMESPACE,
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
access_key_id=config.get("ANALYTICDB_KEY_ID"),
access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
region_id=config.get("ANALYTICDB_REGION_ID"),
instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
account=config.get("ANALYTICDB_ACCOUNT"),
account_password=config.get("ANALYTICDB_PASSWORD"),
namespace=config.get("ANALYTICDB_NAMESPACE"),
namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
),
)
)

View File

@@ -3,9 +3,9 @@ from typing import Any, Optional
import chromadb
from chromadb import QueryResult, Settings
from flask import current_app
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -111,8 +111,7 @@ class ChromaVector(BaseVector):
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -134,14 +133,15 @@ class ChromaVectorFactory(AbstractVectorFactory):
}
dataset.index_struct = json.dumps(index_struct_dict)
config = current_app.config
return ChromaVector(
collection_name=collection_name,
config=ChromaConfig(
host=dify_config.CHROMA_HOST,
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
host=config.get('CHROMA_HOST'),
port=int(config.get('CHROMA_PORT')),
tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT),
database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE),
auth_provider=config.get('CHROMA_AUTH_PROVIDER'),
auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'),
),
)

View File

@@ -3,10 +3,10 @@ import logging
from typing import Any, Optional
from uuid import uuid4
from flask import current_app
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException, connections
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -275,14 +275,15 @@ class MilvusVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
config = current_app.config
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
host=dify_config.MILVUS_HOST,
port=dify_config.MILVUS_PORT,
user=dify_config.MILVUS_USER,
password=dify_config.MILVUS_PASSWORD,
secure=dify_config.MILVUS_SECURE,
database=dify_config.MILVUS_DATABASE,
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
database=config.get('MILVUS_DATABASE'),
)
)

View File

@@ -5,9 +5,9 @@ from enum import Enum
from typing import Any
from clickhouse_connect import get_client
from flask import current_app
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -156,14 +156,15 @@ class MyScaleVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
config = current_app.config
return MyScaleVector(
collection_name=collection_name,
config=MyScaleConfig(
host=dify_config.MYSCALE_HOST,
port=dify_config.MYSCALE_PORT,
user=dify_config.MYSCALE_USER,
password=dify_config.MYSCALE_PASSWORD,
database=dify_config.MYSCALE_DATABASE,
fts_params=dify_config.MYSCALE_FTS_PARAMS,
host=config.get("MYSCALE_HOST", "localhost"),
port=int(config.get("MYSCALE_PORT", 8123)),
user=config.get("MYSCALE_USER", "default"),
password=config.get("MYSCALE_PASSWORD", ""),
database=config.get("MYSCALE_DATABASE", "default"),
fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
),
)

View File

@@ -4,11 +4,11 @@ import ssl
from typing import Any, Optional
from uuid import uuid4
from flask import current_app
from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -257,13 +257,14 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
config = current_app.config
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST,
port=dify_config.OPENSEARCH_PORT,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
secure=dify_config.OPENSEARCH_SECURE,
host=config.get('OPENSEARCH_HOST'),
port=config.get('OPENSEARCH_PORT'),
user=config.get('OPENSEARCH_USER'),
password=config.get('OPENSEARCH_PASSWORD'),
secure=config.get('OPENSEARCH_SECURE'),
)
return OpenSearchVector(

View File

@@ -6,9 +6,9 @@ from typing import Any
import numpy
import oracledb
from flask import current_app
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -44,11 +44,11 @@ class OracleVectorConfig(BaseModel):
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id varchar2(100)
id varchar2(100)
,text CLOB NOT NULL
,meta JSON
,embedding vector NOT NULL
)
)
"""
@@ -219,13 +219,14 @@ class OracleVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
config = current_app.config
return OracleVector(
collection_name=collection_name,
config=OracleVectorConfig(
host=dify_config.ORACLE_HOST,
port=dify_config.ORACLE_PORT,
user=dify_config.ORACLE_USER,
password=dify_config.ORACLE_PASSWORD,
database=dify_config.ORACLE_DATABASE,
host=config.get("ORACLE_HOST"),
port=config.get("ORACLE_PORT"),
user=config.get("ORACLE_USER"),
password=config.get("ORACLE_PASSWORD"),
database=config.get("ORACLE_DATABASE"),
),
)

View File

@@ -3,6 +3,7 @@ import logging
from typing import Any
from uuid import UUID, uuid4
from flask import current_app
from numpy import ndarray
from pgvecto_rs.sqlalchemy import Vector
from pydantic import BaseModel, model_validator
@@ -11,7 +12,6 @@ from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -93,7 +93,7 @@ class PGVectoRS(BaseVector):
text TEXT NOT NULL,
meta JSONB NOT NULL,
vector vector({dimension}) NOT NULL
) using heap;
) using heap;
""")
session.execute(create_statement)
index_statement = sql_text(f"""
@@ -233,15 +233,15 @@ class PGVectoRSFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dim = len(embeddings.embed_query("pgvecto_rs"))
config = current_app.config
return PGVectoRS(
collection_name=collection_name,
config=PgvectoRSConfig(
host=dify_config.PGVECTO_RS_HOST,
port=dify_config.PGVECTO_RS_PORT,
user=dify_config.PGVECTO_RS_USER,
password=dify_config.PGVECTO_RS_PASSWORD,
database=dify_config.PGVECTO_RS_DATABASE,
host=config.get('PGVECTO_RS_HOST'),
port=config.get('PGVECTO_RS_PORT'),
user=config.get('PGVECTO_RS_USER'),
password=config.get('PGVECTO_RS_PASSWORD'),
database=config.get('PGVECTO_RS_DATABASE'),
),
dim=dim
)
)

View File

@@ -5,9 +5,9 @@ from typing import Any
import psycopg2.extras
import psycopg2.pool
from flask import current_app
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding vector({dimension}) NOT NULL
) using heap;
) using heap;
"""
@@ -185,13 +185,14 @@ class PGVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
config = current_app.config
return PGVector(
collection_name=collection_name,
config=PGVectorConfig(
host=dify_config.PGVECTOR_HOST,
port=dify_config.PGVECTOR_PORT,
user=dify_config.PGVECTOR_USER,
password=dify_config.PGVECTOR_PASSWORD,
database=dify_config.PGVECTOR_DATABASE,
host=config.get("PGVECTOR_HOST"),
port=config.get("PGVECTOR_PORT"),
user=config.get("PGVECTOR_USER"),
password=config.get("PGVECTOR_PASSWORD"),
database=config.get("PGVECTOR_DATABASE"),
),
)
)

View File

@@ -19,7 +19,6 @@ from qdrant_client.http.models import (
)
from qdrant_client.local.qdrant_local import QdrantLocal
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -362,8 +361,6 @@ class QdrantVector(BaseVector):
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -447,11 +444,11 @@ class QdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=QdrantConfig(
endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY,
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=config.root_path,
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
grpc_port=config.get('QDRANT_GRPC_PORT'),
prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
)
)

View File

@@ -2,6 +2,7 @@ import json
import uuid
from typing import Any, Optional
from flask import current_app
from pydantic import BaseModel, model_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
@@ -18,7 +19,6 @@ try:
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
@@ -85,7 +85,7 @@ class RelytVector(BaseVector):
document TEXT NOT NULL,
metadata JSON NOT NULL,
embedding vector({dimension}) NOT NULL
) using heap;
) using heap;
""")
session.execute(create_statement)
index_statement = sql_text(f"""
@@ -313,14 +313,15 @@ class RelytVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
config = current_app.config
return RelytVector(
collection_name=collection_name,
config=RelytConfig(
host=dify_config.RELYT_HOST,
port=dify_config.RELYT_PORT,
user=dify_config.RELYT_USER,
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
host=config.get('RELYT_HOST'),
port=config.get('RELYT_PORT'),
user=config.get('RELYT_USER'),
password=config.get('RELYT_PASSWORD'),
database=config.get('RELYT_DATABASE'),
),
group_id=dataset.id
)

View File

@@ -1,13 +1,13 @@
import json
from typing import Any, Optional
from flask import current_app
from pydantic import BaseModel
from tcvectordb import VectorDBClient
from tcvectordb.model import document, enum
from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -212,15 +212,16 @@ class TencentVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
config = current_app.config
return TencentVector(
collection_name=collection_name,
config=TencentConfig(
url=dify_config.TENCENT_VECTOR_DB_URL,
api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
username=dify_config.TENCENT_VECTOR_DB_USERNAME,
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
url=config.get('TENCENT_VECTOR_DB_URL'),
api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
username=config.get('TENCENT_VECTOR_DB_USERNAME'),
database=config.get('TENCENT_VECTOR_DB_DATABASE'),
shard=config.get('TENCENT_VECTOR_DB_SHARD'),
replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
)
)
)

View File

@@ -3,12 +3,12 @@ import logging
from typing import Any
import sqlalchemy
from flask import current_app
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -198,8 +198,8 @@ class TiDBVector(BaseVector):
with Session(self._engine) as session:
select_statement = sql_text(
f"""SELECT meta, text, distance FROM (
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
FROM {self._collection_name}
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
FROM {self._collection_name}
ORDER BY distance
LIMIT {top_k}
) t WHERE distance < {distance};"""
@@ -234,14 +234,15 @@ class TiDBVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
config = current_app.config
return TiDBVector(
collection_name=collection_name,
config=TiDBVectorConfig(
host=dify_config.TIDB_VECTOR_HOST,
port=dify_config.TIDB_VECTOR_PORT,
user=dify_config.TIDB_VECTOR_USER,
password=dify_config.TIDB_VECTOR_PASSWORD,
database=dify_config.TIDB_VECTOR_DATABASE,
program_name=dify_config.APPLICATION_NAME,
host=config.get('TIDB_VECTOR_HOST'),
port=config.get('TIDB_VECTOR_PORT'),
user=config.get('TIDB_VECTOR_USER'),
password=config.get('TIDB_VECTOR_PASSWORD'),
database=config.get('TIDB_VECTOR_DATABASE'),
program_name=config.get('APPLICATION_NAME'),
),
)
)

View File

@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any
from configs import dify_config
from flask import current_app
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -36,7 +37,8 @@ class Vector:
self._vector_processor = self._init_vector()
def _init_vector(self) -> BaseVector:
vector_type = dify_config.VECTOR_STORE
config = current_app.config
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']

Some files were not shown because too many files have changed in this diff.