mirror of
https://github.com/langgenius/dify.git
synced 2026-01-08 07:14:14 +00:00
Compare commits
1 Commits
fix/extra-
...
fix/index-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ea5e8ee7cc |
@@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o
|
||||
|
||||
Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
|
||||
|
||||
Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
|
||||
Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and steps to troubleshoot.
|
||||
|
||||
### 5. Visit dify in your browser
|
||||
|
||||
|
||||
@@ -2,17 +2,17 @@
|
||||
|
||||
考虑到我们的现状,我们需要灵活快速地交付,但我们也希望确保像你这样的贡献者在贡献过程中获得尽可能顺畅的体验。我们为此编写了这份贡献指南,旨在让你熟悉代码库和我们与贡献者的合作方式,以便你能快速进入有趣的部分。
|
||||
|
||||
这份指南,就像 Dify 本身一样,是一个不断改进的工作。如果有时它落后于实际项目,我们非常感谢你的理解,并欢迎提供任何反馈以供我们改进。
|
||||
这份指南,就像 Dify 本身一样,是一个不断改进的工作。如果有时它落后于实际项目,我们非常感谢你的理解,并欢迎任何反馈以供我们改进。
|
||||
|
||||
在许可方面,请花一分钟阅读我们简短的 [许可证和贡献者协议](./LICENSE)。社区还遵守 [行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
|
||||
在许可方面,请花一分钟阅读我们简短的[许可证和贡献者协议](./LICENSE)。社区还遵守[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
|
||||
|
||||
## 在开始之前
|
||||
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或[创建](https://github.com/langgenius/dify/issues/new/choose)一个新问题。我们将问题分为两类:
|
||||
|
||||
### 功能请求:
|
||||
|
||||
* 如果您要提出新的功能请求,请解释所提议的功能的目标,并尽可能提供详细的上下文。[@perzeusss](https://github.com/perzeuss) 制作了一个很好的 [功能请求助手](https://udify.app/chat/MK2kVSnw1gakVwMX),可以帮助您起草需求。随时尝试一下。
|
||||
* 如果您要提出新的功能请求,请解释所提议的功能的目标,并尽可能提供详细的上下文。[@perzeusss](https://github.com/perzeuss)制作了一个很好的[功能请求助手](https://udify.app/chat/MK2kVSnw1gakVwMX),可以帮助您起草需求。随时尝试一下。
|
||||
|
||||
* 如果您想从现有问题中选择一个,请在其下方留下评论表示您的意愿。
|
||||
|
||||
@@ -20,44 +20,45 @@
|
||||
|
||||
根据所提议的功能所属的领域不同,您可能需要与不同的团队成员交流。以下是我们团队成员目前正在从事的各个领域的概述:
|
||||
|
||||
| 团队成员 | 工作范围 |
|
||||
| Member | Scope |
|
||||
| ------------------------------------------------------------ | ---------------------------------------------------- |
|
||||
| [@yeuoly](https://github.com/Yeuoly) | 架构 Agents |
|
||||
| [@jyong](https://github.com/JohnJyong) | RAG 流水线设计 |
|
||||
| [@GarfieldDai](https://github.com/GarfieldDai) | 构建 workflow 编排 |
|
||||
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | 让我们的前端更易用 |
|
||||
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 开发人员体验, 综合事项联系人 |
|
||||
| [@takatost](https://github.com/takatost) | 产品整体方向和架构 |
|
||||
| [@yeuoly](https://github.com/Yeuoly) | Architecting Agents |
|
||||
| [@jyong](https://github.com/JohnJyong) | RAG pipeline design |
|
||||
| [@GarfieldDai](https://github.com/GarfieldDai) | Building workflow orchestrations |
|
||||
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Making our frontend a breeze to use |
|
||||
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Developer experience, points of contact for anything |
|
||||
| [@takatost](https://github.com/takatost) | Overall product direction and architecture |
|
||||
|
||||
事项优先级:
|
||||
How we prioritize:
|
||||
|
||||
| 功能类型 | 优先级 |
|
||||
| Feature Type | Priority |
|
||||
| ------------------------------------------------------------ | --------------- |
|
||||
| 被团队成员标记为高优先级的功能 | 高优先级 |
|
||||
| 在 [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) 内反馈的常见功能请求 | 中等优先级 |
|
||||
| 非核心功能和小幅改进 | 低优先级 |
|
||||
| 有价值当不紧急 | 未来功能 |
|
||||
| High-Priority Features as being labeled by a team member | High Priority |
|
||||
| Popular feature requests from our [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium Priority |
|
||||
| Non-core features and minor enhancements | Low Priority |
|
||||
| Valuable but not immediate | Future-Feature |
|
||||
|
||||
### 其他任何事情(例如 bug 报告、性能优化、拼写错误更正):
|
||||
### 其他任何事情(例如bug报告、性能优化、拼写错误更正):
|
||||
* 立即开始编码。
|
||||
|
||||
事项优先级:
|
||||
How we prioritize:
|
||||
|
||||
| Issue 类型 | 优先级 |
|
||||
| Issue Type | Priority |
|
||||
| ------------------------------------------------------------ | --------------- |
|
||||
| 核心功能的 Bugs(例如无法登录、应用无法工作、安全漏洞) | 紧急 |
|
||||
| 非紧急 bugs, 性能提升 | 中等优先级 |
|
||||
| 小幅修复(错别字, 能正常工作但存在误导的 UI) | 低优先级 |
|
||||
| Bugs in core functions (cannot login, applications not working, security loopholes) | Critical |
|
||||
| Non-critical bugs, performance boosts | Medium Priority |
|
||||
| Minor fixes (typos, confusing but working UI) | Low Priority |
|
||||
|
||||
|
||||
## 安装
|
||||
|
||||
以下是设置 Dify 进行开发的步骤:
|
||||
以下是设置Dify进行开发的步骤:
|
||||
|
||||
### 1. Fork 该仓库
|
||||
### 1. Fork该仓库
|
||||
|
||||
### 2. 克隆仓库
|
||||
|
||||
从终端克隆代码仓库:
|
||||
从终端克隆fork的仓库:
|
||||
|
||||
```
|
||||
git clone git@github.com:<github_username>/dify.git
|
||||
@@ -75,72 +76,72 @@ Dify 依赖以下工具和库:
|
||||
|
||||
### 4. 安装
|
||||
|
||||
Dify 由后端和前端组成。通过 `cd api/` 导航到后端目录,然后按照 [后端 README](api/README.md) 进行安装。在另一个终端中,通过 `cd web/` 导航到前端目录,然后按照 [前端 README](web/README.md) 进行安装。
|
||||
Dify由后端和前端组成。通过`cd api/`导航到后端目录,然后按照[后端README](api/README.md)进行安装。在另一个终端中,通过`cd web/`导航到前端目录,然后按照[前端README](web/README.md)进行安装。
|
||||
|
||||
查看 [安装常见问题解答](https://docs.dify.ai/v/zh-hans/learn-more/faq/install-faq) 以获取常见问题列表和故障排除步骤。
|
||||
查看[安装常见问题解答](https://docs.dify.ai/getting-started/faq/install-faq)以获取常见问题列表和故障排除步骤。
|
||||
|
||||
### 5. 在浏览器中访问 Dify
|
||||
### 5. 在浏览器中访问Dify
|
||||
|
||||
为了验证您的设置,打开浏览器并访问 [http://localhost:3000](http://localhost:3000)(默认或您自定义的 URL 和端口)。现在您应该看到 Dify 正在运行。
|
||||
为了验证您的设置,打开浏览器并访问[http://localhost:3000](http://localhost:3000)(默认或您自定义的URL和端口)。现在您应该看到Dify正在运行。
|
||||
|
||||
## 开发
|
||||
|
||||
如果您要添加模型提供程序,请参考 [此指南](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)。
|
||||
如果您要添加模型提供程序,请参考[此指南](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)。
|
||||
|
||||
如果您要向 Agent 或 Workflow 添加工具提供程序,请参考 [此指南](./api/core/tools/README.md)。
|
||||
如果您要向Agent或Workflow添加工具提供程序,请参考[此指南](./api/core/tools/README.md)。
|
||||
|
||||
为了帮助您快速了解您的贡献在哪个部分,以下是 Dify 后端和前端的简要注释大纲:
|
||||
为了帮助您快速了解您的贡献在哪个部分,以下是Dify后端和前端的简要注释大纲:
|
||||
|
||||
### 后端
|
||||
|
||||
Dify 的后端使用 Python 编写,使用 [Flask](https://flask.palletsprojects.com/en/3.0.x/) 框架。它使用 [SQLAlchemy](https://www.sqlalchemy.org/) 作为 ORM,使用 [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) 作为任务队列。授权逻辑通过 Flask-login 进行处理。
|
||||
Dify的后端使用Python编写,使用[Flask](https://flask.palletsprojects.com/en/3.0.x/)框架。它使用[SQLAlchemy](https://www.sqlalchemy.org/)作为ORM,使用[Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html)作为任务队列。授权逻辑通过Flask-login进行处理。
|
||||
|
||||
```
|
||||
[api/]
|
||||
├── constants // 用于整个代码库的常量设置。
|
||||
├── controllers // API 路由定义和请求处理逻辑。
|
||||
├── core // 核心应用编排、模型集成和工具。
|
||||
├── docker // Docker 和容器化相关配置。
|
||||
├── events // 事件处理和处理。
|
||||
├── extensions // 与第三方框架/平台的扩展。
|
||||
├── fields // 用于序列化/封装的字段定义。
|
||||
├── libs // 可重用的库和助手。
|
||||
├── migrations // 数据库迁移脚本。
|
||||
├── models // 数据库模型和架构定义。
|
||||
├── services // 指定业务逻辑。
|
||||
├── storage // 私钥存储。
|
||||
├── tasks // 异步任务和后台作业的处理。
|
||||
├── constants // Constant settings used throughout code base.
|
||||
├── controllers // API route definitions and request handling logic.
|
||||
├── core // Core application orchestration, model integrations, and tools.
|
||||
├── docker // Docker & containerization related configurations.
|
||||
├── events // Event handling and processing
|
||||
├── extensions // Extensions with 3rd party frameworks/platforms.
|
||||
├── fields // field definitions for serialization/marshalling.
|
||||
├── libs // Reusable libraries and helpers.
|
||||
├── migrations // Scripts for database migration.
|
||||
├── models // Database models & schema definitions.
|
||||
├── services // Specifies business logic.
|
||||
├── storage // Private key storage.
|
||||
├── tasks // Handling of async tasks and background jobs.
|
||||
└── tests
|
||||
```
|
||||
|
||||
### 前端
|
||||
|
||||
该网站使用基于 Typescript 的 [Next.js](https://nextjs.org/) 模板进行引导,并使用 [Tailwind CSS](https://tailwindcss.com/) 进行样式设计。[React-i18next](https://react.i18next.com/) 用于国际化。
|
||||
该网站使用基于Typescript的[Next.js](https://nextjs.org/)模板进行引导,并使用[Tailwind CSS](https://tailwindcss.com/)进行样式设计。[React-i18next](https://react.i18next.com/)用于国际化。
|
||||
|
||||
```
|
||||
[web/]
|
||||
├── app // 布局、页面和组件
|
||||
│ ├── (commonLayout) // 整个应用通用的布局
|
||||
│ ├── (shareLayout) // 在特定会话中共享的布局
|
||||
│ ├── activate // 激活页面
|
||||
│ ├── components // 页面和布局共享的组件
|
||||
│ ├── install // 安装页面
|
||||
│ ├── signin // 登录页面
|
||||
│ └── styles // 全局共享的样式
|
||||
├── assets // 静态资源
|
||||
├── bin // 构建步骤运行的脚本
|
||||
├── config // 可调整的设置和选项
|
||||
├── context // 应用中不同部分使用的共享上下文
|
||||
├── dictionaries // 语言特定的翻译文件
|
||||
├── docker // 容器配置
|
||||
├── hooks // 可重用的钩子
|
||||
├── i18n // 国际化配置
|
||||
├── models // 描述数据模型和 API 响应的形状
|
||||
├── public // 如 favicon 等元资源
|
||||
├── service // 定义 API 操作的形状
|
||||
├── app // layouts, pages, and components
|
||||
│ ├── (commonLayout) // common layout used throughout the app
|
||||
│ ├── (shareLayout) // layouts specifically shared across token-specific sessions
|
||||
│ ├── activate // activate page
|
||||
│ ├── components // shared by pages and layouts
|
||||
│ ├── install // install page
|
||||
│ ├── signin // signin page
|
||||
│ └── styles // globally shared styles
|
||||
├── assets // Static assets
|
||||
├── bin // scripts ran at build step
|
||||
├── config // adjustable settings and options
|
||||
├── context // shared contexts used by different portions of the app
|
||||
├── dictionaries // Language-specific translate files
|
||||
├── docker // container configurations
|
||||
├── hooks // Reusable hooks
|
||||
├── i18n // Internationalization configuration
|
||||
├── models // describes data models & shapes of API responses
|
||||
├── public // meta assets like favicon
|
||||
├── service // specifies shapes of API actions
|
||||
├── test
|
||||
├── types // 函数参数和返回值的描述
|
||||
└── utils // 共享的实用函数
|
||||
├── types // descriptions of function params and return values
|
||||
└── utils // Shared utility functions
|
||||
```
|
||||
|
||||
## 提交你的 PR
|
||||
|
||||
@@ -82,7 +82,7 @@ Dify はバックエンドとフロントエンドから構成されています
|
||||
まず`cd api/`でバックエンドのディレクトリに移動し、[Backend README](api/README.md)に従ってインストールします。
|
||||
次に別のターミナルで、`cd web/`でフロントエンドのディレクトリに移動し、[Frontend README](web/README.md)に従ってインストールしてください。
|
||||
|
||||
よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/v/japanese/learn-more/faq/install-faq) を確認してください。
|
||||
よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) を確認してください。
|
||||
|
||||
### 5. ブラウザで dify にアクセスする
|
||||
|
||||
|
||||
@@ -256,7 +256,3 @@ WORKFLOW_CALL_MAX_DEPTH=5
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
|
||||
# Celery beat configuration
|
||||
CELERY_BEAT_SCHEDULER_TIME=1
|
||||
@@ -23,7 +23,6 @@ class SecurityConfig(BaseSettings):
|
||||
default=24,
|
||||
)
|
||||
|
||||
|
||||
class AppExecutionConfig(BaseSettings):
|
||||
"""
|
||||
App Execution configs
|
||||
@@ -436,13 +435,6 @@ class ImageFormatConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CeleryBeatConfig(BaseSettings):
|
||||
CELERY_BEAT_SCHEDULER_TIME: int = Field(
|
||||
description='the time of the celery scheduler, default to 1 day',
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@@ -470,6 +462,5 @@ class FeatureConfig(
|
||||
|
||||
# hosted services config
|
||||
HostedServiceConfig,
|
||||
CeleryBeatConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -79,7 +79,7 @@ class HostedAzureOpenAiConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
|
||||
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
|
||||
description='',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, PositiveInt
|
||||
|
||||
@@ -7,32 +8,32 @@ class MyScaleConfig(BaseModel):
|
||||
MyScale configs
|
||||
"""
|
||||
|
||||
MYSCALE_HOST: str = Field(
|
||||
MYSCALE_HOST: Optional[str] = Field(
|
||||
description='MyScale host',
|
||||
default='localhost',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_PORT: PositiveInt = Field(
|
||||
MYSCALE_PORT: Optional[PositiveInt] = Field(
|
||||
description='MyScale port',
|
||||
default=8123,
|
||||
)
|
||||
|
||||
MYSCALE_USER: str = Field(
|
||||
MYSCALE_USER: Optional[str] = Field(
|
||||
description='MyScale user',
|
||||
default='default',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_PASSWORD: str = Field(
|
||||
MYSCALE_PASSWORD: Optional[str] = Field(
|
||||
description='MyScale password',
|
||||
default='',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_DATABASE: str = Field(
|
||||
MYSCALE_DATABASE: Optional[str] = Field(
|
||||
description='MyScale database name',
|
||||
default='default',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MYSCALE_FTS_PARAMS: str = Field(
|
||||
MYSCALE_FTS_PARAMS: Optional[str] = Field(
|
||||
description='MyScale fts index parameters',
|
||||
default='',
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description='Dify version',
|
||||
default='0.6.14',
|
||||
default='0.6.13',
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@@ -15,7 +15,6 @@ from fields.app_fields import (
|
||||
app_pagination_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
|
||||
@@ -98,42 +97,8 @@ class AppImportApi(Resource):
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app = AppDslService.import_and_create_new_app(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data=args['data'],
|
||||
args=args,
|
||||
account=current_user
|
||||
)
|
||||
|
||||
return app, 201
|
||||
|
||||
|
||||
class AppImportFromUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
@cloud_edition_billing_resource_check('apps')
|
||||
def post(self):
|
||||
"""Import app from url"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('url', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('name', type=str, location='json')
|
||||
parser.add_argument('description', type=str, location='json')
|
||||
parser.add_argument('icon', type=str, location='json')
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
app = AppDslService.import_and_create_new_app_from_url(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
url=args['url'],
|
||||
args=args,
|
||||
account=current_user
|
||||
)
|
||||
app_service = AppService()
|
||||
app = app_service.import_app(current_user.current_tenant_id, args['data'], args, current_user)
|
||||
|
||||
return app, 201
|
||||
|
||||
@@ -212,13 +177,9 @@ class AppCopyApi(Resource):
|
||||
parser.add_argument('icon_background', type=str, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
data = AppDslService.export_dsl(app_model=app_model)
|
||||
app = AppDslService.import_and_create_new_app(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data=data,
|
||||
args=args,
|
||||
account=current_user
|
||||
)
|
||||
app_service = AppService()
|
||||
data = app_service.export_app(app_model)
|
||||
app = app_service.import_app(current_user.current_tenant_id, data, args, current_user)
|
||||
|
||||
return app, 201
|
||||
|
||||
@@ -234,8 +195,10 @@ class AppExportApi(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
return {
|
||||
"data": AppDslService.export_dsl(app_model=app_model)
|
||||
"data": app_service.export_app(app_model)
|
||||
}
|
||||
|
||||
|
||||
@@ -359,7 +322,6 @@ class AppTraceApi(Resource):
|
||||
|
||||
api.add_resource(AppListApi, '/apps')
|
||||
api.add_resource(AppImportApi, '/apps/import')
|
||||
api.add_resource(AppImportFromUrlApi, '/apps/import/url')
|
||||
api.add_resource(AppApi, '/apps/<uuid:app_id>')
|
||||
api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
|
||||
api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
|
||||
|
||||
@@ -20,7 +20,6 @@ from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import App, AppMode
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.workflow_service import WorkflowService
|
||||
@@ -129,7 +128,8 @@ class DraftWorkflowImportApi(Resource):
|
||||
parser.add_argument('data', type=str, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow = AppDslService.import_and_overwrite_workflow(
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.import_draft_workflow(
|
||||
app_model=app_model,
|
||||
data=args['data'],
|
||||
account=current_user
|
||||
|
||||
@@ -545,15 +545,15 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
RetrievalMethod.SEMANTIC_SEARCH
|
||||
]
|
||||
}
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
RetrievalMethod.SEMANTIC_SEARCH,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH,
|
||||
RetrievalMethod.HYBRID_SEARCH,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
@@ -569,15 +569,15 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
RetrievalMethod.SEMANTIC_SEARCH
|
||||
]
|
||||
}
|
||||
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
RetrievalMethod.SEMANTIC_SEARCH,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH,
|
||||
RetrievalMethod.HYBRID_SEARCH,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
|
||||
@@ -349,7 +349,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
if document.indexing_status in ['completed', 'error']:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
indexing_runner.calculate_tokens(document)
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
|
||||
@@ -75,7 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
)
|
||||
|
||||
if last_id is not None:
|
||||
last_segment = db.session.get(DocumentSegment, str(last_id))
|
||||
last_segment = DocumentSegment.query.get(str(last_id))
|
||||
if last_segment:
|
||||
query = query.filter(
|
||||
DocumentSegment.position > last_segment.position)
|
||||
|
||||
@@ -78,12 +78,10 @@ class ChatTextApi(InstalledAppResource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('text', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id', None)
|
||||
text = args.get('text', None)
|
||||
message_id = args.get('message_id')
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
@@ -97,8 +95,7 @@ class ChatTextApi(InstalledAppResource):
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
voice=voice,
|
||||
text=text
|
||||
voice=voice
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
|
||||
@@ -117,7 +117,7 @@ class MemberUpdateRoleApi(Resource):
|
||||
if not TenantAccountRole.is_valid_role(new_role):
|
||||
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
|
||||
|
||||
member = db.session.get(Account, str(member_id))
|
||||
member = Account.query.get(str(member_id))
|
||||
if not member:
|
||||
abort(404)
|
||||
|
||||
|
||||
@@ -3,9 +3,8 @@ from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
|
||||
from flask import abort, request
|
||||
from flask import abort, current_app, request
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
@@ -13,12 +12,12 @@ from models.model import EndUser
|
||||
def inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
if not current_app.config['INNER_API']:
|
||||
abort(404)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get('X-Inner-Api-Key')
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
|
||||
if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
|
||||
abort(404)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
@@ -29,7 +28,7 @@ def inner_api_only(view):
|
||||
def inner_api_user_auth(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.INNER_API:
|
||||
if not current_app.config['INNER_API']:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
@@ -78,7 +78,7 @@ class AppParameterApi(Resource):
|
||||
"transfer_methods": ["remote_url", "local_file"]
|
||||
}}),
|
||||
'system_parameters': {
|
||||
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -76,12 +76,10 @@ class TextApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('text', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id', None)
|
||||
text = args.get('text', None)
|
||||
message_id = args.get('message_id')
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
@@ -89,15 +87,15 @@ class TextApi(Resource):
|
||||
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
|
||||
else:
|
||||
try:
|
||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
|
||||
voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
|
||||
'voice')
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
end_user=end_user.external_user_id,
|
||||
voice=voice,
|
||||
text=text
|
||||
voice=voice
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.service_api import api
|
||||
@@ -21,43 +21,14 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunApi(Resource):
|
||||
workflow_run_fields = {
|
||||
'id': fields.String,
|
||||
'workflow_id': fields.String,
|
||||
'status': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'outputs': fields.Raw,
|
||||
'error': fields.String,
|
||||
'total_steps': fields.Integer,
|
||||
'total_tokens': fields.Integer,
|
||||
'created_at': fields.DateTime,
|
||||
'finished_at': fields.DateTime,
|
||||
'elapsed_time': fields.Float,
|
||||
}
|
||||
|
||||
@validate_app_token
|
||||
@marshal_with(workflow_run_fields)
|
||||
def get(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
Get a workflow task running detail
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
|
||||
return workflow_run
|
||||
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""
|
||||
@@ -117,5 +88,5 @@ class WorkflowTaskStopApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(WorkflowRunApi, '/workflows/run/<string:workflow_id>', '/workflows/run')
|
||||
api.add_resource(WorkflowRunApi, '/workflows/run')
|
||||
api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop')
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from flask import current_app
|
||||
from flask_restful import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.service_api import api
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ class IndexApi(Resource):
|
||||
return {
|
||||
"welcome": "Dify OpenAPI",
|
||||
"api_version": "v1",
|
||||
"server_version": dify_config.CURRENT_VERSION,
|
||||
"server_version": current_app.config['CURRENT_VERSION']
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import api
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
@@ -75,7 +75,7 @@ class AppParameterApi(WebApiResource):
|
||||
"transfer_methods": ["remote_url", "local_file"]
|
||||
}}),
|
||||
'system_parameters': {
|
||||
'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -74,12 +74,10 @@ class TextApi(WebApiResource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('message_id', type=str, required=False, location='json')
|
||||
parser.add_argument('voice', type=str, location='json')
|
||||
parser.add_argument('text', type=str, location='json')
|
||||
parser.add_argument('streaming', type=bool, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get('message_id', None)
|
||||
text = args.get('text', None)
|
||||
message_id = args.get('message_id')
|
||||
if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict):
|
||||
@@ -96,8 +94,7 @@ class TextApi(WebApiResource):
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
end_user=end_user.external_user_id,
|
||||
voice=voice,
|
||||
text=text
|
||||
voice=voice
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
@@ -84,7 +84,7 @@ class AppSiteInfo:
|
||||
self.can_replace_logo = can_replace_logo
|
||||
|
||||
if can_replace_logo:
|
||||
base_url = dify_config.FILES_URL
|
||||
base_url = current_app.config.get('FILES_URL')
|
||||
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
|
||||
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
|
||||
self.custom_config = {
|
||||
|
||||
@@ -255,12 +255,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=stopped_by),
|
||||
|
||||
@@ -214,6 +214,61 @@ class IndexingRunner:
|
||||
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
def calculate_tokens(self, tenant_id: str, tokens: int, dataset_id: str = None,
|
||||
indexing_technique: str = 'economy') -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model_instance = None
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||
if dataset.embedding_model_provider:
|
||||
embedding_model_instance = self.model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
else:
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
preview_texts = []
|
||||
total_segments = 0
|
||||
total_price = 0
|
||||
currency = 'USD'
|
||||
if embedding_model_instance:
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
|
||||
embedding_price_info = embedding_model_type_instance.get_price(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
total_price = '{:f}'.format(embedding_price_info.total_amount)
|
||||
currency = embedding_price_info.currency
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": total_price,
|
||||
"currency": currency,
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
|
||||
|
||||
def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
|
||||
indexing_technique: str = 'economy') -> dict:
|
||||
|
||||
@@ -64,7 +64,6 @@ User Input:
|
||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
"Please help me predict the three most likely questions that human would ask, "
|
||||
"and keeping each question under 20 characters.\n"
|
||||
"MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
|
||||
"The output must be an array in JSON format following the specified schema:\n"
|
||||
"[\"question1\",\"question2\",\"question3\"]\n"
|
||||
)
|
||||
|
||||
@@ -103,7 +103,7 @@ class TokenBufferMemory:
|
||||
|
||||
if curr_message_tokens > max_token_limit:
|
||||
pruned_memory = []
|
||||
while curr_message_tokens > max_token_limit and len(prompt_messages)>1:
|
||||
while curr_message_tokens > max_token_limit and prompt_messages:
|
||||
pruned_memory.append(prompt_messages.pop(0))
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(
|
||||
prompt_messages
|
||||
|
||||
@@ -27,9 +27,9 @@ parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 8192
|
||||
max: 4096
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
|
||||
@@ -113,11 +113,6 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
if system:
|
||||
extra_model_kwargs['system'] = system
|
||||
|
||||
# Add the new header for claude-3-5-sonnet-20240620 model
|
||||
extra_headers = {}
|
||||
if model == "claude-3-5-sonnet-20240620":
|
||||
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
|
||||
|
||||
if tools:
|
||||
extra_model_kwargs['tools'] = [
|
||||
self._transform_tool_prompt(tool) for tool in tools
|
||||
@@ -126,7 +121,6 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
model=model,
|
||||
messages=prompt_message_dicts,
|
||||
stream=stream,
|
||||
extra_headers=extra_headers,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
)
|
||||
@@ -136,7 +130,6 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
model=model,
|
||||
messages=prompt_message_dicts,
|
||||
stream=stream,
|
||||
extra_headers=extra_headers,
|
||||
**model_parameters,
|
||||
**extra_model_kwargs
|
||||
)
|
||||
@@ -145,7 +138,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
|
||||
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
|
||||
|
||||
@@ -71,9 +71,6 @@ model_credential_schema:
|
||||
- label:
|
||||
en_US: '2024-02-01'
|
||||
value: '2024-02-01'
|
||||
- label:
|
||||
en_US: '2024-06-01'
|
||||
value: '2024-06-01'
|
||||
placeholder:
|
||||
zh_Hans: 在此选择您的 API 版本
|
||||
en_US: Select your API Version here
|
||||
|
||||
@@ -66,10 +66,6 @@ provider_credential_schema:
|
||||
label:
|
||||
en_US: Europe (Frankfurt)
|
||||
zh_Hans: 欧洲 (法兰克福)
|
||||
- value: eu-west-2
|
||||
label:
|
||||
en_US: Eu west London (London)
|
||||
zh_Hans: 欧洲西部 (伦敦)
|
||||
- value: us-gov-west-1
|
||||
label:
|
||||
en_US: AWS GovCloud (US-West)
|
||||
|
||||
@@ -48,28 +48,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
# please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
|
||||
# TODO There is invoke issue: context limit on Cohere Model, will add them after fixed.
|
||||
CONVERSE_API_ENABLED_MODEL_INFO=[
|
||||
{'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False},
|
||||
{'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False},
|
||||
{'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True},
|
||||
{'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False},
|
||||
{'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
|
||||
{'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
|
||||
{'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
|
||||
{'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
|
||||
{'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _find_model_info(model_id):
|
||||
for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO:
|
||||
if model_id.startswith(model['prefix']):
|
||||
return model
|
||||
logger.info(f"current model id: {model_id} did not support by Converse API")
|
||||
return None
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
@@ -88,12 +66,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
|
||||
model_info= BedrockLargeLanguageModel._find_model_info(model)
|
||||
if model_info:
|
||||
model_info['model'] = model
|
||||
# invoke models via boto3 converse API
|
||||
return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
|
||||
# TODO: consolidate different invocation methods for models based on base model capabilities
|
||||
# invoke anthropic models via boto3 client
|
||||
if "anthropic" in model:
|
||||
return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
|
||||
# invoke Cohere models via boto3 client
|
||||
if "cohere.command-r" in model:
|
||||
return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
|
||||
@@ -175,12 +151,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
|
||||
def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model with converse API
|
||||
Invoke Anthropic large language model
|
||||
|
||||
:param model_info: model information
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
@@ -197,24 +173,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
|
||||
|
||||
parameters = {
|
||||
'modelId': model_info['model'],
|
||||
'modelId': model,
|
||||
'messages': prompt_message_dicts,
|
||||
'inferenceConfig': inference_config,
|
||||
'additionalModelRequestFields': additional_model_fields,
|
||||
}
|
||||
|
||||
if model_info['support_system_prompts'] and system and len(system) > 0:
|
||||
if system and len(system) > 0:
|
||||
parameters['system'] = system
|
||||
|
||||
if model_info['support_tool_use'] and tools:
|
||||
if tools:
|
||||
parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
|
||||
|
||||
if stream:
|
||||
response = bedrock_client.converse_stream(**parameters)
|
||||
return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages)
|
||||
return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
|
||||
else:
|
||||
response = bedrock_client.converse(**parameters)
|
||||
return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages)
|
||||
return self._handle_converse_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_converse_response(self, model: str, credentials: dict, response: dict,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
@@ -227,30 +203,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: full response chunk generator result
|
||||
"""
|
||||
response_content = response['output']['message']['content']
|
||||
# transform assistant message to prompt message
|
||||
if response['stopReason'] == 'tool_use':
|
||||
tool_calls = []
|
||||
text, tool_use = self._extract_tool_use(response_content)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_use['toolUseId'],
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_use['name'],
|
||||
arguments=json.dumps(tool_use['input'])
|
||||
)
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=text,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
else:
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=response_content[0]['text']
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=response['output']['message']['content'][0]['text']
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
if response['usage']:
|
||||
@@ -273,18 +229,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
return result
|
||||
|
||||
def _extract_tool_use(self, content:dict)-> tuple[str, dict]:
|
||||
tool_use = {}
|
||||
text = ''
|
||||
for item in content:
|
||||
if 'toolUse' in item:
|
||||
tool_use = item['toolUse']
|
||||
elif 'text' in item:
|
||||
text = item['text']
|
||||
else:
|
||||
raise ValueError(f"Got unknown item: {item}")
|
||||
return text, tool_use
|
||||
|
||||
def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
|
||||
prompt_messages: list[PromptMessage], ) -> Generator:
|
||||
"""
|
||||
@@ -396,12 +340,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
|
||||
system = []
|
||||
prompt_message_dicts = []
|
||||
for message in prompt_messages:
|
||||
if isinstance(message, SystemPromptMessage):
|
||||
message.content=message.content.strip()
|
||||
system.append({"text": message.content})
|
||||
else:
|
||||
|
||||
prompt_message_dicts = []
|
||||
for message in prompt_messages:
|
||||
if not isinstance(message, SystemPromptMessage):
|
||||
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
|
||||
|
||||
return system, prompt_message_dicts
|
||||
@@ -502,6 +448,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_dict
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
|
||||
|
||||
@@ -2,9 +2,6 @@ model: mistral.mistral-large-2402-v1:0
|
||||
label:
|
||||
en_US: Mistral Large
|
||||
model_type: llm
|
||||
features:
|
||||
- tool-call
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
|
||||
@@ -2,8 +2,6 @@ model: mistral.mistral-small-2402-v1:0
|
||||
label:
|
||||
en_US: Mistral Small
|
||||
model_type: llm
|
||||
features:
|
||||
- tool-call
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
|
||||
@@ -7,7 +7,7 @@ features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@@ -7,7 +7,7 @@ features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
- gpt-4
|
||||
- gpt-4o
|
||||
- gpt-4o-2024-05-13
|
||||
- gpt-4o-mini
|
||||
- gpt-4o-mini-2024-07-18
|
||||
- gpt-4-turbo
|
||||
- gpt-4-turbo-2024-04-09
|
||||
- gpt-4-turbo-preview
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
model: gpt-4o-mini-2024-07-18
|
||||
label:
|
||||
zh_Hans: gpt-4o-mini-2024-07-18
|
||||
en_US: gpt-4o-mini-2024-07-18
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.15'
|
||||
output: '0.60'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@@ -1,44 +0,0 @@
|
||||
model: gpt-4o-mini
|
||||
label:
|
||||
zh_Hans: gpt-4o-mini
|
||||
en_US: gpt-4o-mini
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.15'
|
||||
output: '0.60'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@@ -616,34 +616,30 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
||||
if function_calling_type == 'tool_call':
|
||||
message_dict["tool_calls"] = [tool_call.dict() for tool_call in
|
||||
message.tool_calls]
|
||||
elif function_calling_type == 'function_call':
|
||||
function_call = message.tool_calls[0]
|
||||
message_dict["function_call"] = {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
}
|
||||
# message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
|
||||
# in
|
||||
# message.tool_calls]
|
||||
|
||||
function_call = message.tool_calls[0]
|
||||
message_dict["function_call"] = {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments,
|
||||
}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
function_calling_type = credentials.get('function_calling_type', 'no_call')
|
||||
if function_calling_type == 'tool_call':
|
||||
message_dict = {
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id
|
||||
}
|
||||
elif function_calling_type == 'function_call':
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.tool_call_id
|
||||
}
|
||||
# message_dict = {
|
||||
# "role": "tool",
|
||||
# "content": message.content,
|
||||
# "tool_call_id": message.tool_call_id
|
||||
# }
|
||||
message_dict = {
|
||||
"role": "tool" if credentials and credentials.get('function_calling_type', 'no_call') == 'tool_call' else "function",
|
||||
"content": message.content,
|
||||
"name": message.tool_call_id
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
- openai/gpt-4o
|
||||
- openai/gpt-4o-mini
|
||||
- openai/gpt-4
|
||||
- openai/gpt-4-32k
|
||||
- openai/gpt-3.5-turbo
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
model: openai/gpt-4o-mini
|
||||
label:
|
||||
en_US: gpt-4o-mini
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 16384
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: "0.15"
|
||||
output: "0.60"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 9.2 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 9.5 KiB |
@@ -1,238 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SageMakerLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
Model class for Cohere large language model.
|
||||
"""
|
||||
sagemaker_client: Any = None
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model, credentials)
|
||||
|
||||
if not self.sagemaker_client:
|
||||
access_key = credentials.get('access_key')
|
||||
secret_key = credentials.get('secret_key')
|
||||
aws_region = credentials.get('aws_region')
|
||||
if aws_region:
|
||||
if access_key and secret_key:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
|
||||
|
||||
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=sagemaker_endpoint,
|
||||
Body=json.dumps(
|
||||
{
|
||||
"inputs": prompt_messages[0].content,
|
||||
"parameters": { "stop" : stop},
|
||||
"history" : []
|
||||
}
|
||||
),
|
||||
ContentType="application/json",
|
||||
)
|
||||
|
||||
assistant_text = response_model['Body'].read().decode('utf8')
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text
|
||||
)
|
||||
|
||||
usage = self._calc_response_usage(model, credentials, 0, 0)
|
||||
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model)
|
||||
|
||||
try:
|
||||
return 0
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvokeBadRequestError,
|
||||
KeyError,
|
||||
ValueError
|
||||
]
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
rules = [
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='temperature',
|
||||
label=I18nObject(
|
||||
zh_Hans='温度',
|
||||
en_US='Temperature'
|
||||
),
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
type=ParameterType.FLOAT,
|
||||
use_template='top_p',
|
||||
label=I18nObject(
|
||||
zh_Hans='Top P',
|
||||
en_US='Top P'
|
||||
)
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens',
|
||||
type=ParameterType.INT,
|
||||
use_template='max_tokens',
|
||||
min=1,
|
||||
max=credentials.get('context_length', 2048),
|
||||
default=512,
|
||||
label=I18nObject(
|
||||
zh_Hans='最大生成长度',
|
||||
en_US='Max Tokens'
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
if completion_type == LLMMode.CHAT:
|
||||
print(f"completion_type : {LLMMode.CHAT.value}")
|
||||
|
||||
if completion_type == LLMMode.COMPLETION:
|
||||
print(f"completion_type : {LLMMode.COMPLETION.value}")
|
||||
|
||||
features = []
|
||||
|
||||
support_function_call = credentials.get('support_function_call', False)
|
||||
if support_function_call:
|
||||
features.append(ModelFeature.TOOL_CALL)
|
||||
|
||||
support_vision = credentials.get('support_vision', False)
|
||||
if support_vision:
|
||||
features.append(ModelFeature.VISION)
|
||||
|
||||
context_length = credentials.get('context_length', 2048)
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
features=features,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: completion_type,
|
||||
ModelPropertyKey.CONTEXT_SIZE: context_length
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
||||
return entity
|
||||
@@ -1,190 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import boto3
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SageMakerRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for Cohere rerank model.
|
||||
"""
|
||||
sagemaker_client: Any = None
|
||||
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str):
|
||||
inputs = [query_input]*len(docs)
|
||||
response_model = self.sagemaker_client.invoke_endpoint(
|
||||
EndpointName=rerank_endpoint,
|
||||
Body=json.dumps(
|
||||
{
|
||||
"inputs": inputs,
|
||||
"docs": docs
|
||||
}
|
||||
),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model['Body'].read().decode('utf8')
|
||||
json_obj = json.loads(json_str)
|
||||
scores = json_obj['scores']
|
||||
return scores if isinstance(scores, list) else [scores]
|
||||
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
if len(docs) == 0:
|
||||
return RerankResult(
|
||||
model=model,
|
||||
docs=docs
|
||||
)
|
||||
|
||||
line = 1
|
||||
if not self.sagemaker_client:
|
||||
access_key = credentials.get('aws_access_key_id')
|
||||
secret_key = credentials.get('aws_secret_access_key')
|
||||
aws_region = credentials.get('aws_region')
|
||||
if aws_region:
|
||||
if access_key and secret_key:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
|
||||
line = 2
|
||||
|
||||
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||
candidate_docs = []
|
||||
|
||||
scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)
|
||||
for idx in range(len(scores)):
|
||||
candidate_docs.append({"content" : docs[idx], "score": scores[idx]})
|
||||
|
||||
sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
|
||||
|
||||
line = 3
|
||||
rerank_documents = []
|
||||
for idx, result in enumerate(candidate_docs):
|
||||
rerank_document = RerankDocument(
|
||||
index=idx,
|
||||
text=result.get('content'),
|
||||
score=result.get('score', -100.0)
|
||||
)
|
||||
|
||||
if score_threshold is not None:
|
||||
if rerank_document.score >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
else:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(
|
||||
model=model,
|
||||
docs=rerank_documents
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'Exception {e}, line : {line}')
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvokeBadRequestError,
|
||||
KeyError,
|
||||
ValueError
|
||||
]
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.RERANK,
|
||||
model_properties={ },
|
||||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
||||
@@ -1,17 +0,0 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SageMakerProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
pass
|
||||
@@ -1,125 +0,0 @@
|
||||
provider: sagemaker
|
||||
label:
|
||||
zh_Hans: Sagemaker
|
||||
en_US: Sagemaker
|
||||
icon_small:
|
||||
en_US: icon_s_en.png
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
description:
|
||||
en_US: Customized model on Sagemaker
|
||||
zh_Hans: Sagemaker上的私有化部署的模型
|
||||
background: "#ECE9E3"
|
||||
help:
|
||||
title:
|
||||
en_US: How to deploy customized model on Sagemaker
|
||||
zh_Hans: 如何在Sagemaker上的私有化部署的模型
|
||||
url:
|
||||
en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint
|
||||
zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: mode
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Completion mode
|
||||
type: select
|
||||
required: false
|
||||
default: chat
|
||||
placeholder:
|
||||
zh_Hans: 选择对话类型
|
||||
en_US: Select completion mode
|
||||
options:
|
||||
- value: completion
|
||||
label:
|
||||
en_US: Completion
|
||||
zh_Hans: 补全
|
||||
- value: chat
|
||||
label:
|
||||
en_US: Chat
|
||||
zh_Hans: 对话
|
||||
- variable: sagemaker_endpoint
|
||||
label:
|
||||
en_US: sagemaker endpoint
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 请输出你的Sagemaker推理端点
|
||||
en_US: Enter your Sagemaker Inference endpoint
|
||||
- variable: aws_access_key_id
|
||||
required: false
|
||||
label:
|
||||
en_US: Access Key (If not provided, credentials are obtained from the running environment.)
|
||||
zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。)
|
||||
type: secret-input
|
||||
placeholder:
|
||||
en_US: Enter your Access Key
|
||||
zh_Hans: 在此输入您的 Access Key
|
||||
- variable: aws_secret_access_key
|
||||
required: false
|
||||
label:
|
||||
en_US: Secret Access Key
|
||||
zh_Hans: Secret Access Key
|
||||
type: secret-input
|
||||
placeholder:
|
||||
en_US: Enter your Secret Access Key
|
||||
zh_Hans: 在此输入您的 Secret Access Key
|
||||
- variable: aws_region
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 地区
|
||||
type: select
|
||||
default: us-east-1
|
||||
options:
|
||||
- value: us-east-1
|
||||
label:
|
||||
en_US: US East (N. Virginia)
|
||||
zh_Hans: 美国东部 (弗吉尼亚北部)
|
||||
- value: us-west-2
|
||||
label:
|
||||
en_US: US West (Oregon)
|
||||
zh_Hans: 美国西部 (俄勒冈州)
|
||||
- value: ap-southeast-1
|
||||
label:
|
||||
en_US: Asia Pacific (Singapore)
|
||||
zh_Hans: 亚太地区 (新加坡)
|
||||
- value: ap-northeast-1
|
||||
label:
|
||||
en_US: Asia Pacific (Tokyo)
|
||||
zh_Hans: 亚太地区 (东京)
|
||||
- value: eu-central-1
|
||||
label:
|
||||
en_US: Europe (Frankfurt)
|
||||
zh_Hans: 欧洲 (法兰克福)
|
||||
- value: us-gov-west-1
|
||||
label:
|
||||
en_US: AWS GovCloud (US-West)
|
||||
zh_Hans: AWS GovCloud (US-West)
|
||||
- value: ap-southeast-2
|
||||
label:
|
||||
en_US: Asia Pacific (Sydney)
|
||||
zh_Hans: 亚太地区 (悉尼)
|
||||
- value: cn-north-1
|
||||
label:
|
||||
en_US: AWS Beijing (cn-north-1)
|
||||
zh_Hans: 中国北京 (cn-north-1)
|
||||
- value: cn-northwest-1
|
||||
label:
|
||||
en_US: AWS Ningxia (cn-northwest-1)
|
||||
zh_Hans: 中国宁夏 (cn-northwest-1)
|
||||
@@ -1,214 +0,0 @@
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import boto3
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
BATCH_SIZE = 20
|
||||
CONTEXT_SIZE=8192
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def batch_generator(generator, batch_size):
|
||||
while True:
|
||||
batch = list(itertools.islice(generator, batch_size))
|
||||
if not batch:
|
||||
break
|
||||
yield batch
|
||||
|
||||
class SageMakerEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Cohere text embedding model.
|
||||
"""
|
||||
sagemaker_client: Any = None
|
||||
|
||||
def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]):
|
||||
response_model = sm_client.invoke_endpoint(
|
||||
EndpointName=endpoint_name,
|
||||
Body=json.dumps(
|
||||
{
|
||||
"inputs": content_list,
|
||||
"parameters": {},
|
||||
"is_query" : False,
|
||||
"instruction" : ''
|
||||
}
|
||||
),
|
||||
ContentType="application/json",
|
||||
)
|
||||
json_str = response_model['Body'].read().decode('utf8')
|
||||
json_obj = json.loads(json_str)
|
||||
embeddings = json_obj['embeddings']
|
||||
return embeddings
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
# get model properties
|
||||
try:
|
||||
line = 1
|
||||
if not self.sagemaker_client:
|
||||
access_key = credentials.get('aws_access_key_id')
|
||||
secret_key = credentials.get('aws_secret_access_key')
|
||||
aws_region = credentials.get('aws_region')
|
||||
if aws_region:
|
||||
if access_key and secret_key:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime",
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.sagemaker_client = boto3.client("sagemaker-runtime")
|
||||
|
||||
line = 2
|
||||
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
|
||||
|
||||
line = 3
|
||||
truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ]
|
||||
|
||||
batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
|
||||
all_embeddings = []
|
||||
|
||||
line = 4
|
||||
for batch in batches:
|
||||
embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
|
||||
all_embeddings.extend(embeddings)
|
||||
|
||||
line = 5
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=0 # It's not SAAS API, usage is meaningless
|
||||
)
|
||||
line = 6
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=all_embeddings,
|
||||
usage=usage,
|
||||
model=model
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'Exception {e}, line : {line}')
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
return 0
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
print("validate_credentials ok....")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
KeyError
|
||||
]
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
|
||||
ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
|
||||
},
|
||||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 9.0 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 1.9 KiB |
@@ -1,6 +0,0 @@
|
||||
- step-1-8k
|
||||
- step-1-32k
|
||||
- step-1-128k
|
||||
- step-1-256k
|
||||
- step-1v-8k
|
||||
- step-1v-32k
|
||||
@@ -1,328 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
self._add_function_call(model, credentials)
|
||||
user = user[:32] if user else None
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model, zh_Hans=model),
|
||||
model_type=ModelType.LLM,
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||
if credentials.get('function_calling_type') == 'tool_call'
|
||||
else [],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)),
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
use_template='temperature',
|
||||
label=I18nObject(en_US='Temperature', zh_Hans='温度'),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
ParameterRule(
|
||||
name='max_tokens',
|
||||
use_template='max_tokens',
|
||||
default=512,
|
||||
min=1,
|
||||
max=int(credentials.get('max_tokens', 1024)),
|
||||
label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
|
||||
type=ParameterType.INT,
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
use_template='top_p',
|
||||
label=I18nObject(en_US='Top P', zh_Hans='Top P'),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def _add_custom_parameters(self, credentials: dict) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
credentials['endpoint_url'] = 'https://api.stepfun.com/v1'
|
||||
|
||||
def _add_function_call(self, model: str, credentials: dict) -> None:
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
if model_schema and {
|
||||
ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
|
||||
}.intersection(model_schema.features or []):
|
||||
credentials['function_calling_type'] = 'tool_call'
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for OpenAI API format
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
sub_messages = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(PromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "text",
|
||||
"text": message_content.data
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
sub_message_dict = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": message_content.data,
|
||||
}
|
||||
}
|
||||
sub_messages.append(sub_message_dict)
|
||||
message_dict = {"role": "user", "content": sub_messages}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = []
|
||||
for function_call in message.tool_calls:
|
||||
message_dict["tool_calls"].append({
|
||||
"id": function_call.id,
|
||||
"type": function_call.type,
|
||||
"function": {
|
||||
"name": function_call.function.name,
|
||||
"arguments": function_call.function.arguments
|
||||
}
|
||||
})
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name:
|
||||
message_dict["name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
||||
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
|
||||
"""
|
||||
Extract tool calls from response
|
||||
|
||||
:param response_tool_calls: response tool calls
|
||||
:return: list of tool calls
|
||||
"""
|
||||
tool_calls = []
|
||||
if response_tool_calls:
|
||||
for response_tool_call in response_tool_calls:
|
||||
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
|
||||
arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
|
||||
)
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=response_tool_call["id"] if response_tool_call.get("id") else "",
|
||||
type=response_tool_call["type"] if response_tool_call.get("type") else "",
|
||||
function=function
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param response: streamed response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
full_assistant_content = ''
|
||||
chunk_index = 0
|
||||
|
||||
def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
|
||||
-> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
return LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=message,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
finish_reason = "Unknown"
|
||||
|
||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||
def get_tool_call(tool_name: str):
|
||||
if not tool_name:
|
||||
return tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id='',
|
||||
type='',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.function.name)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
|
||||
if chunk:
|
||||
# ignore sse comments
|
||||
if chunk.startswith(':'):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
# stream ended
|
||||
except json.JSONDecodeError as e:
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered."
|
||||
)
|
||||
break
|
||||
if not chunk_json or len(chunk_json['choices']) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk_json['choices'][0]
|
||||
finish_reason = chunk_json['choices'][0].get('finish_reason')
|
||||
chunk_index += 1
|
||||
|
||||
if 'delta' in choice:
|
||||
delta = choice['delta']
|
||||
delta_content = delta.get('content')
|
||||
|
||||
assistant_message_tool_calls = delta.get('tool_calls', None)
|
||||
# assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
if assistant_message_tool_calls:
|
||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
increase_tool_call(tool_calls)
|
||||
|
||||
if delta_content is None or delta_content == '':
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta_content,
|
||||
tool_calls=tool_calls if assistant_message_tool_calls else []
|
||||
)
|
||||
|
||||
full_assistant_content += delta_content
|
||||
elif 'text' in choice:
|
||||
choice_text = choice.get('text', '')
|
||||
if choice_text == '':
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
|
||||
full_assistant_content += choice_text
|
||||
else:
|
||||
continue
|
||||
|
||||
# check payload indicator for completion
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
)
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if tools_calls:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(
|
||||
tool_calls=tools_calls,
|
||||
content=""
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
model: step-1-128k
|
||||
label:
|
||||
zh_Hans: step-1-128k
|
||||
en_US: step-1-128k
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 128000
|
||||
pricing:
|
||||
input: '0.04'
|
||||
output: '0.20'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -1,25 +0,0 @@
|
||||
model: step-1-256k
|
||||
label:
|
||||
zh_Hans: step-1-256k
|
||||
en_US: step-1-256k
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 256000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 256000
|
||||
pricing:
|
||||
input: '0.095'
|
||||
output: '0.300'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -1,28 +0,0 @@
|
||||
model: step-1-32k
|
||||
label:
|
||||
zh_Hans: step-1-32k
|
||||
en_US: step-1-32k
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32000
|
||||
pricing:
|
||||
input: '0.015'
|
||||
output: '0.070'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -1,28 +0,0 @@
|
||||
model: step-1-8k
|
||||
label:
|
||||
zh_Hans: step-1-8k
|
||||
en_US: step-1-8k
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8000
|
||||
pricing:
|
||||
input: '0.005'
|
||||
output: '0.020'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -1,25 +0,0 @@
|
||||
model: step-1v-32k
|
||||
label:
|
||||
zh_Hans: step-1v-32k
|
||||
en_US: step-1v-32k
|
||||
model_type: llm
|
||||
features:
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32000
|
||||
pricing:
|
||||
input: '0.015'
|
||||
output: '0.070'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -1,25 +0,0 @@
|
||||
model: step-1v-8k
|
||||
label:
|
||||
zh_Hans: step-1v-8k
|
||||
en_US: step-1v-8k
|
||||
model_type: llm
|
||||
features:
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.005'
|
||||
output: '0.020'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -1,30 +0,0 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StepfunProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(
|
||||
model='step-1-8k',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
@@ -1,81 +0,0 @@
|
||||
provider: stepfun
|
||||
label:
|
||||
zh_Hans: 阶跃星辰
|
||||
en_US: Stepfun
|
||||
description:
|
||||
en_US: Models provided by stepfun, such as step-1-8k, step-1-32k、step-1v-8k、step-1v-32k, step-1-128k and step-1-256k
|
||||
zh_Hans: 阶跃星辰提供的模型,例如 step-1-8k、step-1-32k、step-1v-8k、step-1v-32k、step-1-128k 和 step-1-256k。
|
||||
icon_small:
|
||||
en_US: icon_s_en.png
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
background: "#FFFFFF"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API Key from stepfun
|
||||
zh_Hans: 从 stepfun 获取 API Key
|
||||
url:
|
||||
en_US: https://platform.stepfun.com/interface-key
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 模型上下文长度
|
||||
en_US: Model context size
|
||||
required: true
|
||||
type: text-input
|
||||
default: '8192'
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
- variable: max_tokens
|
||||
label:
|
||||
zh_Hans: 最大 token 上限
|
||||
en_US: Upper bound for max tokens
|
||||
default: '8192'
|
||||
type: text-input
|
||||
- variable: function_calling_type
|
||||
label:
|
||||
en_US: Function calling
|
||||
type: select
|
||||
required: false
|
||||
default: no_call
|
||||
options:
|
||||
- value: no_call
|
||||
label:
|
||||
en_US: Not supported
|
||||
zh_Hans: 不支持
|
||||
- value: tool_call
|
||||
label:
|
||||
en_US: Tool Call
|
||||
zh_Hans: Tool Call
|
||||
@@ -29,7 +29,7 @@ model_credential_schema:
|
||||
label:
|
||||
zh_Hans: 服务器URL
|
||||
en_US: Server url
|
||||
type: text-input
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000
|
||||
|
||||
@@ -35,4 +35,3 @@ parameter_rules:
|
||||
zh_Hans: 禁用模型自行进行外部搜索。
|
||||
en_US: Disable the model to perform external search.
|
||||
required: false
|
||||
deprecated: true
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
model: ernie-4.0-8k-latest
|
||||
model: ernie-4.0-8k-Latest
|
||||
label:
|
||||
en_US: Ernie-4.0-8K-Latest
|
||||
model_type: llm
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
model: ernie-4.0-turbo-8k-preview
|
||||
label:
|
||||
en_US: Ernie-4.0-turbo-8k-preview
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.1
|
||||
max: 1.0
|
||||
default: 0.8
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 2
|
||||
max: 2048
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
default: 1.0
|
||||
min: 1.0
|
||||
max: 2.0
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: disable_search
|
||||
label:
|
||||
zh_Hans: 禁用搜索
|
||||
en_US: Disable Search
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 禁用模型自行进行外部搜索。
|
||||
en_US: Disable the model to perform external search.
|
||||
required: false
|
||||
@@ -1,40 +0,0 @@
|
||||
model: ernie-4.0-turbo-8k
|
||||
label:
|
||||
en_US: Ernie-4.0-turbo-8K
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.1
|
||||
max: 1.0
|
||||
default: 0.8
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 2
|
||||
max: 2048
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
default: 1.0
|
||||
min: 1.0
|
||||
max: 2.0
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
- name: disable_search
|
||||
label:
|
||||
zh_Hans: 禁用搜索
|
||||
en_US: Disable Search
|
||||
type: boolean
|
||||
help:
|
||||
zh_Hans: 禁用模型自行进行外部搜索。
|
||||
en_US: Disable the model to perform external search.
|
||||
required: false
|
||||
@@ -28,4 +28,3 @@ parameter_rules:
|
||||
default: 1.0
|
||||
min: 1.0
|
||||
max: 2.0
|
||||
deprecated: true
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
model: ernie-character-8k-0321
|
||||
label:
|
||||
en_US: ERNIE-Character-8K
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.1
|
||||
max: 1.0
|
||||
default: 0.95
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0
|
||||
max: 1.0
|
||||
default: 0.7
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 2
|
||||
max: 1024
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
default: 1.0
|
||||
min: 1.0
|
||||
max: 2.0
|
||||
@@ -28,4 +28,3 @@ parameter_rules:
|
||||
default: 1.0
|
||||
min: 1.0
|
||||
max: 2.0
|
||||
deprecated: true
|
||||
|
||||
@@ -28,4 +28,3 @@ parameter_rules:
|
||||
default: 1.0
|
||||
min: 1.0
|
||||
max: 2.0
|
||||
deprecated: true
|
||||
|
||||
@@ -97,7 +97,6 @@ class BaiduAccessToken:
|
||||
baidu_access_tokens_lock.release()
|
||||
return token
|
||||
|
||||
|
||||
class ErnieMessage:
|
||||
class Role(Enum):
|
||||
USER = 'user'
|
||||
@@ -138,10 +137,7 @@ class ErnieBotModel:
|
||||
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
||||
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
|
||||
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-4.0-tutbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
|
||||
'ernie-4.0-tutbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
||||
}
|
||||
|
||||
function_calling_supports = [
|
||||
@@ -152,9 +148,7 @@ class ErnieBotModel:
|
||||
'ernie-3.5-8k-1222',
|
||||
'ernie-3.5-4k-0205',
|
||||
'ernie-3.5-128k',
|
||||
'ernie-4.0-8k',
|
||||
'ernie-4.0-turbo-8k',
|
||||
'ernie-4.0-turbo-8k-preview'
|
||||
'ernie-4.0-8k'
|
||||
]
|
||||
|
||||
api_key: str = ''
|
||||
|
||||
@@ -453,11 +453,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
api_key = credentials.get('api_key') or "abc"
|
||||
|
||||
client = OpenAI(
|
||||
base_url=f'{credentials["server_url"]}/v1',
|
||||
api_key=api_key,
|
||||
api_key='abc',
|
||||
max_retries=3,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@@ -44,23 +44,15 @@ class XinferenceRerankModel(RerankModel):
|
||||
docs=[]
|
||||
)
|
||||
|
||||
server_url = credentials['server_url']
|
||||
model_uid = credentials['model_uid']
|
||||
api_key = credentials.get('api_key')
|
||||
if server_url.endswith('/'):
|
||||
server_url = server_url[:-1]
|
||||
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||
|
||||
try:
|
||||
handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
|
||||
response = handle.rerank(
|
||||
documents=docs,
|
||||
query=query,
|
||||
top_n=top_n,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
|
||||
response = handle.rerank(
|
||||
documents=docs,
|
||||
query=query,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
rerank_documents = []
|
||||
for idx, result in enumerate(response['results']):
|
||||
@@ -110,7 +102,7 @@ class XinferenceRerankModel(RerankModel):
|
||||
if not isinstance(xinference_client, RESTfulRerankModelHandle):
|
||||
raise InvokeBadRequestError(
|
||||
'please check model type, the model you want to invoke is not a rerank model')
|
||||
|
||||
|
||||
self.invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
|
||||
@@ -99,9 +99,9 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||
}
|
||||
|
||||
def _speech2text_invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
file: IO[bytes],
|
||||
language: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
@@ -121,24 +121,17 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||
:param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi ll use log probability to automatically increase the temperature until certain thresholds are hit.
|
||||
:return: text for given audio file
|
||||
"""
|
||||
server_url = credentials['server_url']
|
||||
model_uid = credentials['model_uid']
|
||||
api_key = credentials.get('api_key')
|
||||
if server_url.endswith('/'):
|
||||
server_url = server_url[:-1]
|
||||
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
try:
|
||||
handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers)
|
||||
response = handle.transcriptions(
|
||||
audio=file,
|
||||
language=language,
|
||||
prompt=prompt,
|
||||
response_format=response_format,
|
||||
temperature=temperature
|
||||
)
|
||||
except RuntimeError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
|
||||
response = handle.transcriptions(
|
||||
audio=file,
|
||||
language = language,
|
||||
prompt = prompt,
|
||||
response_format = response_format,
|
||||
temperature = temperature
|
||||
)
|
||||
|
||||
return response["text"]
|
||||
|
||||
|
||||
@@ -43,17 +43,16 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
server_url = credentials['server_url']
|
||||
model_uid = credentials['model_uid']
|
||||
api_key = credentials.get('api_key')
|
||||
|
||||
if server_url.endswith('/'):
|
||||
server_url = server_url[:-1]
|
||||
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
|
||||
|
||||
try:
|
||||
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers)
|
||||
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
|
||||
embeddings = handle.create_embedding(input=texts)
|
||||
except RuntimeError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
raise InvokeServerUnavailableError(e)
|
||||
|
||||
"""
|
||||
for convenience, the response json is like:
|
||||
class Embedding(TypedDict):
|
||||
@@ -107,7 +106,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
try:
|
||||
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
|
||||
server_url = credentials['server_url']
|
||||
model_uid = credentials['model_uid']
|
||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
|
||||
@@ -118,7 +117,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
client = Client(base_url=server_url)
|
||||
|
||||
|
||||
try:
|
||||
handle = client.get_model(model_uid=model_uid)
|
||||
except RuntimeError as e:
|
||||
@@ -152,7 +151,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
KeyError
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
@@ -187,7 +186,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
|
||||
@@ -32,7 +32,7 @@ model_credential_schema:
|
||||
label:
|
||||
zh_Hans: 服务器URL
|
||||
en_US: Server url
|
||||
type: text-input
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
|
||||
@@ -46,12 +46,3 @@ model_credential_schema:
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的Model UID
|
||||
en_US: Enter the model uid
|
||||
- variable: api_key
|
||||
label:
|
||||
zh_Hans: API密钥
|
||||
en_US: API key
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的API密钥
|
||||
en_US: Enter the api key
|
||||
|
||||
@@ -20,7 +20,7 @@ class ZhipuaiProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(
|
||||
model='glm-4',
|
||||
model='chatglm_turbo',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
|
||||
import httpx
|
||||
|
||||
from ..core._base_api import BaseAPI
|
||||
from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven
|
||||
from ..core._base_type import NOT_GIVEN, Headers, NotGiven
|
||||
from ..core._http_client import make_user_request_input
|
||||
from ..types.image import ImagesResponded
|
||||
|
||||
@@ -28,9 +28,7 @@ class Images(BaseAPI):
|
||||
size: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
style: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
user: str | NotGiven = NOT_GIVEN,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_body: Body | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ImagesResponded:
|
||||
@@ -48,12 +46,9 @@ class Images(BaseAPI):
|
||||
"size": size,
|
||||
"style": style,
|
||||
"user": user,
|
||||
"request_id": request_id,
|
||||
},
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout
|
||||
extra_headers=extra_headers, timeout=timeout
|
||||
),
|
||||
cast_type=_cast_type,
|
||||
enable_stream=False,
|
||||
|
||||
@@ -11,7 +11,7 @@ from tenacity import retry
|
||||
from tenacity.stop import stop_after_attempt
|
||||
|
||||
from . import _errors
|
||||
from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
|
||||
from ._base_type import NOT_GIVEN, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
|
||||
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
|
||||
from ._files import make_httpx_files
|
||||
from ._request_opt import ClientRequestParam, UserRequestInput
|
||||
@@ -358,7 +358,6 @@ def make_user_request_input(
|
||||
max_retries: int | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers = None,
|
||||
extra_body: Body | None = None,
|
||||
query: Query | None = None,
|
||||
) -> UserRequestInput:
|
||||
options: UserRequestInput = {}
|
||||
@@ -371,7 +370,5 @@ def make_user_request_input(
|
||||
options['timeout'] = timeout
|
||||
if query is not None:
|
||||
options["params"] = query
|
||||
if extra_body is not None:
|
||||
options["extra_json"] = cast(AnyMapping, extra_body)
|
||||
|
||||
return options
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.datasource.keyword.jieba.jieba import Jieba
|
||||
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
||||
from core.rag.models.document import Document
|
||||
@@ -13,8 +14,8 @@ class Keyword:
|
||||
self._keyword_processor = self._init_keyword()
|
||||
|
||||
def _init_keyword(self) -> BaseKeyword:
|
||||
config = dify_config
|
||||
keyword_type = config.KEYWORD_STORE
|
||||
config = current_app.config
|
||||
keyword_type = config.get('KEYWORD_STORE')
|
||||
|
||||
if not keyword_type:
|
||||
raise ValueError("Keyword store must be specified.")
|
||||
|
||||
@@ -11,7 +11,7 @@ from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH,
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
@@ -86,7 +86,7 @@ class RetrievalService:
|
||||
exception_message = ';\n'.join(exceptions)
|
||||
raise Exception(exception_message)
|
||||
|
||||
if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
if retrival_method == RetrievalMethod.HYBRID_SEARCH:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
@@ -142,7 +142,7 @@ class RetrievalService:
|
||||
)
|
||||
|
||||
if documents:
|
||||
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
|
||||
if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
|
||||
all_documents.extend(data_post_processor.invoke(
|
||||
query=query,
|
||||
@@ -174,7 +174,7 @@ class RetrievalService:
|
||||
top_k=top_k
|
||||
)
|
||||
if documents:
|
||||
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
|
||||
if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
|
||||
all_documents.extend(data_post_processor.invoke(
|
||||
query=query,
|
||||
|
||||
@@ -7,8 +7,8 @@ _import_err_msg = (
|
||||
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
|
||||
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
|
||||
)
|
||||
from flask import current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
@@ -36,7 +36,7 @@ class AnalyticdbConfig(BaseModel):
|
||||
"region_id": self.region_id,
|
||||
"read_timeout": self.read_timeout,
|
||||
}
|
||||
|
||||
|
||||
class AnalyticdbVector(BaseVector):
|
||||
_instance = None
|
||||
_init = False
|
||||
@@ -45,7 +45,7 @@ class AnalyticdbVector(BaseVector):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
|
||||
def __init__(self, collection_name: str, config: AnalyticdbConfig):
|
||||
# collection_name must be updated every time
|
||||
self._collection_name = collection_name.lower()
|
||||
@@ -105,7 +105,7 @@ class AnalyticdbVector(BaseVector):
|
||||
raise ValueError(
|
||||
f"failed to create namespace {self.config.namespace}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
from Tea.exceptions import TeaException
|
||||
@@ -149,7 +149,7 @@ class AnalyticdbVector(BaseVector):
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.ANALYTICDB
|
||||
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection_if_not_exists(dimension)
|
||||
@@ -199,7 +199,7 @@ class AnalyticdbVector(BaseVector):
|
||||
)
|
||||
response = self._client.query_collection_data(request)
|
||||
return len(response.body.matches.match) > 0
|
||||
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
ids_str = ",".join(f"'{id}'" for id in ids)
|
||||
@@ -260,7 +260,7 @@ class AnalyticdbVector(BaseVector):
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
score_threshold = (
|
||||
@@ -291,7 +291,7 @@ class AnalyticdbVector(BaseVector):
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
|
||||
def delete(self) -> None:
|
||||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
|
||||
request = gpdb_20160503_models.DeleteCollectionRequest(
|
||||
@@ -316,18 +316,17 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
|
||||
)
|
||||
|
||||
# TODO handle optional params
|
||||
config = current_app.config
|
||||
return AnalyticdbVector(
|
||||
collection_name,
|
||||
AnalyticdbConfig(
|
||||
access_key_id=dify_config.ANALYTICDB_KEY_ID,
|
||||
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
|
||||
region_id=dify_config.ANALYTICDB_REGION_ID,
|
||||
instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
|
||||
account=dify_config.ANALYTICDB_ACCOUNT,
|
||||
account_password=dify_config.ANALYTICDB_PASSWORD,
|
||||
namespace=dify_config.ANALYTICDB_NAMESPACE,
|
||||
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
|
||||
access_key_id=config.get("ANALYTICDB_KEY_ID"),
|
||||
access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
|
||||
region_id=config.get("ANALYTICDB_REGION_ID"),
|
||||
instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
|
||||
account=config.get("ANALYTICDB_ACCOUNT"),
|
||||
account_password=config.get("ANALYTICDB_PASSWORD"),
|
||||
namespace=config.get("ANALYTICDB_NAMESPACE"),
|
||||
namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -3,9 +3,9 @@ from typing import Any, Optional
|
||||
|
||||
import chromadb
|
||||
from chromadb import QueryResult, Settings
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
@@ -111,8 +111,7 @@ class ChromaVector(BaseVector):
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -134,14 +133,15 @@ class ChromaVectorFactory(AbstractVectorFactory):
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
config = current_app.config
|
||||
return ChromaVector(
|
||||
collection_name=collection_name,
|
||||
config=ChromaConfig(
|
||||
host=dify_config.CHROMA_HOST,
|
||||
port=dify_config.CHROMA_PORT,
|
||||
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
|
||||
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
|
||||
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
|
||||
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
|
||||
host=config.get('CHROMA_HOST'),
|
||||
port=int(config.get('CHROMA_PORT')),
|
||||
tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT),
|
||||
database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE),
|
||||
auth_provider=config.get('CHROMA_AUTH_PROVIDER'),
|
||||
auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,10 +3,10 @@ import logging
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pymilvus import MilvusClient, MilvusException, connections
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@@ -275,14 +275,15 @@ class MilvusVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return MilvusVector(
|
||||
collection_name=collection_name,
|
||||
config=MilvusConfig(
|
||||
host=dify_config.MILVUS_HOST,
|
||||
port=dify_config.MILVUS_PORT,
|
||||
user=dify_config.MILVUS_USER,
|
||||
password=dify_config.MILVUS_PASSWORD,
|
||||
secure=dify_config.MILVUS_SECURE,
|
||||
database=dify_config.MILVUS_DATABASE,
|
||||
host=config.get('MILVUS_HOST'),
|
||||
port=config.get('MILVUS_PORT'),
|
||||
user=config.get('MILVUS_USER'),
|
||||
password=config.get('MILVUS_PASSWORD'),
|
||||
secure=config.get('MILVUS_SECURE'),
|
||||
database=config.get('MILVUS_DATABASE'),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,9 +5,9 @@ from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from clickhouse_connect import get_client
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
@@ -156,14 +156,15 @@ class MyScaleVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return MyScaleVector(
|
||||
collection_name=collection_name,
|
||||
config=MyScaleConfig(
|
||||
host=dify_config.MYSCALE_HOST,
|
||||
port=dify_config.MYSCALE_PORT,
|
||||
user=dify_config.MYSCALE_USER,
|
||||
password=dify_config.MYSCALE_PASSWORD,
|
||||
database=dify_config.MYSCALE_DATABASE,
|
||||
fts_params=dify_config.MYSCALE_FTS_PARAMS,
|
||||
host=config.get("MYSCALE_HOST", "localhost"),
|
||||
port=int(config.get("MYSCALE_PORT", 8123)),
|
||||
user=config.get("MYSCALE_USER", "default"),
|
||||
password=config.get("MYSCALE_PASSWORD", ""),
|
||||
database=config.get("MYSCALE_DATABASE", "default"),
|
||||
fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -4,11 +4,11 @@ import ssl
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import current_app
|
||||
from opensearchpy import OpenSearch, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@@ -257,13 +257,14 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
|
||||
open_search_config = OpenSearchConfig(
|
||||
host=dify_config.OPENSEARCH_HOST,
|
||||
port=dify_config.OPENSEARCH_PORT,
|
||||
user=dify_config.OPENSEARCH_USER,
|
||||
password=dify_config.OPENSEARCH_PASSWORD,
|
||||
secure=dify_config.OPENSEARCH_SECURE,
|
||||
host=config.get('OPENSEARCH_HOST'),
|
||||
port=config.get('OPENSEARCH_PORT'),
|
||||
user=config.get('OPENSEARCH_USER'),
|
||||
password=config.get('OPENSEARCH_PASSWORD'),
|
||||
secure=config.get('OPENSEARCH_SECURE'),
|
||||
)
|
||||
|
||||
return OpenSearchVector(
|
||||
|
||||
@@ -6,9 +6,9 @@ from typing import Any
|
||||
|
||||
import numpy
|
||||
import oracledb
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
@@ -44,11 +44,11 @@ class OracleVectorConfig(BaseModel):
|
||||
|
||||
SQL_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
id varchar2(100)
|
||||
id varchar2(100)
|
||||
,text CLOB NOT NULL
|
||||
,meta JSON
|
||||
,embedding vector NOT NULL
|
||||
)
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
@@ -219,13 +219,14 @@ class OracleVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return OracleVector(
|
||||
collection_name=collection_name,
|
||||
config=OracleVectorConfig(
|
||||
host=dify_config.ORACLE_HOST,
|
||||
port=dify_config.ORACLE_PORT,
|
||||
user=dify_config.ORACLE_USER,
|
||||
password=dify_config.ORACLE_PASSWORD,
|
||||
database=dify_config.ORACLE_DATABASE,
|
||||
host=config.get("ORACLE_HOST"),
|
||||
port=config.get("ORACLE_PORT"),
|
||||
user=config.get("ORACLE_USER"),
|
||||
password=config.get("ORACLE_PASSWORD"),
|
||||
database=config.get("ORACLE_DATABASE"),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from flask import current_app
|
||||
from numpy import ndarray
|
||||
from pgvecto_rs.sqlalchemy import Vector
|
||||
from pydantic import BaseModel, model_validator
|
||||
@@ -11,7 +12,6 @@ from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@@ -93,7 +93,7 @@ class PGVectoRS(BaseVector):
|
||||
text TEXT NOT NULL,
|
||||
meta JSONB NOT NULL,
|
||||
vector vector({dimension}) NOT NULL
|
||||
) using heap;
|
||||
) using heap;
|
||||
""")
|
||||
session.execute(create_statement)
|
||||
index_statement = sql_text(f"""
|
||||
@@ -233,15 +233,15 @@ class PGVectoRSFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
|
||||
dim = len(embeddings.embed_query("pgvecto_rs"))
|
||||
|
||||
config = current_app.config
|
||||
return PGVectoRS(
|
||||
collection_name=collection_name,
|
||||
config=PgvectoRSConfig(
|
||||
host=dify_config.PGVECTO_RS_HOST,
|
||||
port=dify_config.PGVECTO_RS_PORT,
|
||||
user=dify_config.PGVECTO_RS_USER,
|
||||
password=dify_config.PGVECTO_RS_PASSWORD,
|
||||
database=dify_config.PGVECTO_RS_DATABASE,
|
||||
host=config.get('PGVECTO_RS_HOST'),
|
||||
port=config.get('PGVECTO_RS_PORT'),
|
||||
user=config.get('PGVECTO_RS_USER'),
|
||||
password=config.get('PGVECTO_RS_PASSWORD'),
|
||||
database=config.get('PGVECTO_RS_DATABASE'),
|
||||
),
|
||||
dim=dim
|
||||
)
|
||||
)
|
||||
@@ -5,9 +5,9 @@ from typing import Any
|
||||
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
@@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
text TEXT NOT NULL,
|
||||
meta JSONB NOT NULL,
|
||||
embedding vector({dimension}) NOT NULL
|
||||
) using heap;
|
||||
) using heap;
|
||||
"""
|
||||
|
||||
|
||||
@@ -185,13 +185,14 @@ class PGVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return PGVector(
|
||||
collection_name=collection_name,
|
||||
config=PGVectorConfig(
|
||||
host=dify_config.PGVECTOR_HOST,
|
||||
port=dify_config.PGVECTOR_PORT,
|
||||
user=dify_config.PGVECTOR_USER,
|
||||
password=dify_config.PGVECTOR_PASSWORD,
|
||||
database=dify_config.PGVECTOR_DATABASE,
|
||||
host=config.get("PGVECTOR_HOST"),
|
||||
port=config.get("PGVECTOR_PORT"),
|
||||
user=config.get("PGVECTOR_USER"),
|
||||
password=config.get("PGVECTOR_PASSWORD"),
|
||||
database=config.get("PGVECTOR_DATABASE"),
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -19,7 +19,6 @@ from qdrant_client.http.models import (
|
||||
)
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@@ -362,8 +361,6 @@ class QdrantVector(BaseVector):
|
||||
metadata=metadata,
|
||||
)
|
||||
docs.append(doc)
|
||||
# Sort the documents by score in descending order
|
||||
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -447,11 +444,11 @@ class QdrantVectorFactory(AbstractVectorFactory):
|
||||
collection_name=collection_name,
|
||||
group_id=dataset.id,
|
||||
config=QdrantConfig(
|
||||
endpoint=dify_config.QDRANT_URL,
|
||||
api_key=dify_config.QDRANT_API_KEY,
|
||||
endpoint=config.get('QDRANT_URL'),
|
||||
api_key=config.get('QDRANT_API_KEY'),
|
||||
root_path=config.root_path,
|
||||
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
|
||||
grpc_port=dify_config.QDRANT_GRPC_PORT,
|
||||
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
|
||||
timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
|
||||
grpc_port=config.get('QDRANT_GRPC_PORT'),
|
||||
prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
@@ -18,7 +19,6 @@ try:
|
||||
except ImportError:
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
@@ -85,7 +85,7 @@ class RelytVector(BaseVector):
|
||||
document TEXT NOT NULL,
|
||||
metadata JSON NOT NULL,
|
||||
embedding vector({dimension}) NOT NULL
|
||||
) using heap;
|
||||
) using heap;
|
||||
""")
|
||||
session.execute(create_statement)
|
||||
index_statement = sql_text(f"""
|
||||
@@ -313,14 +313,15 @@ class RelytVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return RelytVector(
|
||||
collection_name=collection_name,
|
||||
config=RelytConfig(
|
||||
host=dify_config.RELYT_HOST,
|
||||
port=dify_config.RELYT_PORT,
|
||||
user=dify_config.RELYT_USER,
|
||||
password=dify_config.RELYT_PASSWORD,
|
||||
database=dify_config.RELYT_DATABASE,
|
||||
host=config.get('RELYT_HOST'),
|
||||
port=config.get('RELYT_PORT'),
|
||||
user=config.get('RELYT_USER'),
|
||||
password=config.get('RELYT_PASSWORD'),
|
||||
database=config.get('RELYT_DATABASE'),
|
||||
),
|
||||
group_id=dataset.id
|
||||
)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel
|
||||
from tcvectordb import VectorDBClient
|
||||
from tcvectordb.model import document, enum
|
||||
from tcvectordb.model import index as vdb_index
|
||||
from tcvectordb.model.document import Filter
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
@@ -212,15 +212,16 @@ class TencentVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return TencentVector(
|
||||
collection_name=collection_name,
|
||||
config=TencentConfig(
|
||||
url=dify_config.TENCENT_VECTOR_DB_URL,
|
||||
api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
|
||||
timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
|
||||
username=dify_config.TENCENT_VECTOR_DB_USERNAME,
|
||||
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
|
||||
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
|
||||
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
|
||||
url=config.get('TENCENT_VECTOR_DB_URL'),
|
||||
api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
|
||||
timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
|
||||
username=config.get('TENCENT_VECTOR_DB_USERNAME'),
|
||||
database=config.get('TENCENT_VECTOR_DB_DATABASE'),
|
||||
shard=config.get('TENCENT_VECTOR_DB_SHARD'),
|
||||
replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -3,12 +3,12 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
from flask import current_app
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
|
||||
from sqlalchemy import text as sql_text
|
||||
from sqlalchemy.orm import Session, declarative_base
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.entity.embedding import Embeddings
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
@@ -198,8 +198,8 @@ class TiDBVector(BaseVector):
|
||||
with Session(self._engine) as session:
|
||||
select_statement = sql_text(
|
||||
f"""SELECT meta, text, distance FROM (
|
||||
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
|
||||
FROM {self._collection_name}
|
||||
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
|
||||
FROM {self._collection_name}
|
||||
ORDER BY distance
|
||||
LIMIT {top_k}
|
||||
) t WHERE distance < {distance};"""
|
||||
@@ -234,14 +234,15 @@ class TiDBVectorFactory(AbstractVectorFactory):
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
|
||||
|
||||
config = current_app.config
|
||||
return TiDBVector(
|
||||
collection_name=collection_name,
|
||||
config=TiDBVectorConfig(
|
||||
host=dify_config.TIDB_VECTOR_HOST,
|
||||
port=dify_config.TIDB_VECTOR_PORT,
|
||||
user=dify_config.TIDB_VECTOR_USER,
|
||||
password=dify_config.TIDB_VECTOR_PASSWORD,
|
||||
database=dify_config.TIDB_VECTOR_DATABASE,
|
||||
program_name=dify_config.APPLICATION_NAME,
|
||||
host=config.get('TIDB_VECTOR_HOST'),
|
||||
port=config.get('TIDB_VECTOR_PORT'),
|
||||
user=config.get('TIDB_VECTOR_USER'),
|
||||
password=config.get('TIDB_VECTOR_PASSWORD'),
|
||||
database=config.get('TIDB_VECTOR_DATABASE'),
|
||||
program_name=config.get('APPLICATION_NAME'),
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -1,7 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from flask import current_app
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
@@ -36,7 +37,8 @@ class Vector:
|
||||
self._vector_processor = self._init_vector()
|
||||
|
||||
def _init_vector(self) -> BaseVector:
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
config = current_app.config
|
||||
vector_type = config.get('VECTOR_STORE')
|
||||
if self._dataset.index_struct_dict:
|
||||
vector_type = self._dataset.index_struct_dict['type']
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user