Mirror of https://github.com/langgenius/dify.git (synced 2026-01-07 14:58:32 +00:00)

Compare commits (36 commits)
Commit SHAs: c038040e1b, 21450b8a51, 5fc1bd026a, d60f1a5601, da83f8403e, 4ff17af5de, a9d1b4e6d7, 66612075d2, b921c55677, bdc5e9ceb0, f2b2effc4b, 301e0496ff, 98660e1f97, 6cf93379b3, 8639abec97, d5361b8d09, 6bfdfab6f3, bec998ab94, 77636945fb, fd5c45ae10, ad71386adf, 043517717e, 76c52300a2, dda32c6880, ac4bb5c35f, a96cae4f44, 7cb75cb2e7, 0940084fd2, 95ad06c8c3, 3c13c4f3ee, 2fe938b7da, 784da52ea6, 78524a56ed, 6c614f0c1f, d42df4ed04, 6d94126368
@@ -1,57 +1,155 @@
# 贡献

所以你想为 Dify 做贡献 - 这太棒了,我们迫不及待地想看到你的贡献。作为一家人员和资金有限的初创公司,我们有着雄心勃勃的目标,希望设计出最直观的工作流程来构建和管理 LLM 应用程序。社区的任何帮助都是宝贵的。

感谢您对 [Dify](https://dify.ai) 的兴趣,并希望您能够做出贡献!在开始之前,请先阅读[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)并查看[现有问题](https://github.com/langgenius/dify/issues)。

本文档介绍了如何设置开发环境以构建和测试 [Dify](https://dify.ai)。

考虑到我们的现状,我们需要灵活快速地交付,但我们也希望确保像你这样的贡献者在贡献过程中获得尽可能顺畅的体验。我们为此编写了这份贡献指南,旨在让你熟悉代码库和我们与贡献者的合作方式,以便你能快速进入有趣的部分。

### 安装依赖项

这份指南,就像 Dify 本身一样,是一个不断改进的工作。如果有时它落后于实际项目,我们非常感谢你的理解,并欢迎任何反馈以供我们改进。

您需要在计算机上安装和配置以下依赖项才能构建 [Dify](https://dify.ai):

在许可方面,请花一分钟阅读我们简短的[许可证和贡献者协议](./license)。社区还遵守[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。

- [Git](http://git-scm.com/)
- [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) 版本 8.x.x 或 [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) 版本 3.10.x

## 在开始之前

## 本地开发

[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或[创建](https://github.com/langgenius/dify/issues/new/choose)一个新问题。我们将问题分为两类:

要设置一个可工作的开发环境,只需 fork 项目的 git 存储库,并使用适当的软件包管理器安装后端和前端依赖项,然后创建并运行 docker-compose。

### 功能请求:

### Fork存储库

* 如果您要提出新的功能请求,请解释所提议的功能的目标,并尽可能提供详细的上下文。[@perzeusss](https://github.com/perzeuss)制作了一个很好的[功能请求助手](https://udify.app/chat/MK2kVSnw1gakVwMX),可以帮助您起草需求。随时尝试一下。

您需要 fork [Git 仓库](https://github.com/langgenius/dify)。

* 如果您想从现有问题中选择一个,请在其下方留下评论表示您的意愿。

### 克隆存储库

相关方向的团队成员将参与其中。如果一切顺利,他们将批准您开始编码。在此之前,请不要开始工作,以免我们提出更改导致您的工作付诸东流。

克隆您在 GitHub 上 fork 的仓库:

根据所提议的功能所属的领域不同,您可能需要与不同的团队成员交流。以下是我们团队成员目前正在从事的各个领域的概述:

| Member | Scope |
| ------------------------------------------------------------ | ---------------------------------------------------- |
| [@yeuoly](https://github.com/Yeuoly) | Architecting Agents |
| [@jyong](https://github.com/JohnJyong) | RAG pipeline design |
| [@GarfieldDai](https://github.com/GarfieldDai) | Building workflow orchestrations |
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Making our frontend a breeze to use |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Developer experience, points of contact for anything |
| [@takatost](https://github.com/takatost) | Overall product direction and architecture |

How we prioritize:

| Feature Type | Priority |
| ------------------------------------------------------------ | --------------- |
| High-Priority Features as being labeled by a team member | High Priority |
| Popular feature requests from our [community feedback board](https://feedback.dify.ai/) | Medium Priority |
| Non-core features and minor enhancements | Low Priority |
| Valuable but not immediate | Future-Feature |

### 其他任何事情(例如bug报告、性能优化、拼写错误更正):

* 立即开始编码。

How we prioritize:

| Issue Type | Priority |
| ------------------------------------------------------------ | --------------- |
| Bugs in core functions (cannot login, applications not working, security loopholes) | Critical |
| Non-critical bugs, performance boosts | Medium Priority |
| Minor fixes (typos, confusing but working UI) | Low Priority |

## 安装

以下是设置Dify进行开发的步骤:

### 1. Fork该仓库

### 2. 克隆仓库

从终端克隆fork的仓库:

```
git clone git@github.com:<github_username>/dify.git
```

### 安装后端

### 3. 验证依赖项

要了解如何安装后端应用程序,请参阅[后端 README](api/README.md)。

Dify 依赖以下工具和库:

### 安装前端

- [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x

要了解如何安装前端应用程序,请参阅[前端 README](web/README.md)。

### 4. 安装

### 在浏览器中访问 Dify

Dify由后端和前端组成。通过`cd api/`导航到后端目录,然后按照[后端README](api/README.md)进行安装。在另一个终端中,通过`cd web/`导航到前端目录,然后按照[前端README](web/README.md)进行安装。

最后,您现在可以访问 [http://localhost:3000](http://localhost:3000) 在本地环境中查看 [Dify](https://dify.ai)。

查看[安装常见问题解答](https://docs.dify.ai/getting-started/faq/install-faq)以获取常见问题列表和故障排除步骤。

## 创建拉取请求

### 5. 在浏览器中访问Dify

在进行更改后,打开一个拉取请求(PR)。提交拉取请求后,Dify 团队/社区的其他人将与您一起审查它。

为了验证您的设置,打开浏览器并访问[http://localhost:3000](http://localhost:3000)(默认或您自定义的URL和端口)。现在您应该看到Dify正在运行。

如果遇到问题,比如合并冲突或不知道如何打开拉取请求,请查看 GitHub 的[拉取请求教程](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests),了解如何解决合并冲突和其他问题。一旦您的 PR 被合并,您将自豪地被列为[贡献者表](https://github.com/langgenius/dify/graphs/contributors)中的一员。

## 开发

## 社区渠道

如果您要添加模型提供程序,请参考[此指南](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)。

遇到困难了吗?有任何问题吗? 加入 [Discord Community Server](https://discord.gg/AhzKf7dNgk),我们将为您提供帮助。

如果您要向Agent或Workflow添加工具提供程序,请参考[此指南](./api/core/tools/README.md)。

### 多语言支持

为了帮助您快速了解您的贡献在哪个部分,以下是Dify后端和前端的简要注释大纲:

需要参与贡献翻译内容,请参阅[前端多语言翻译 README](web/i18n/README_CN.md)。

### 后端

Dify的后端使用Python编写,使用[Flask](https://flask.palletsprojects.com/en/3.0.x/)框架。它使用[SQLAlchemy](https://www.sqlalchemy.org/)作为ORM,使用[Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html)作为任务队列。授权逻辑通过Flask-login进行处理。

```
[api/]
├── constants // Constant settings used throughout code base.
├── controllers // API route definitions and request handling logic.
├── core // Core application orchestration, model integrations, and tools.
├── docker // Docker & containerization related configurations.
├── events // Event handling and processing
├── extensions // Extensions with 3rd party frameworks/platforms.
├── fields // field definitions for serialization/marshalling.
├── libs // Reusable libraries and helpers.
├── migrations // Scripts for database migration.
├── models // Database models & schema definitions.
├── services // Specifies business logic.
├── storage // Private key storage.
├── tasks // Handling of async tasks and background jobs.
└── tests
```

### 前端

该网站使用基于Typescript的[Next.js](https://nextjs.org/)模板进行引导,并使用[Tailwind CSS](https://tailwindcss.com/)进行样式设计。[React-i18next](https://react.i18next.com/)用于国际化。

```
[web/]
├── app // layouts, pages, and components
│ ├── (commonLayout) // common layout used throughout the app
│ ├── (shareLayout) // layouts specifically shared across token-specific sessions
│ ├── activate // activate page
│ ├── components // shared by pages and layouts
│ ├── install // install page
│ ├── signin // signin page
│ └── styles // globally shared styles
├── assets // Static assets
├── bin // scripts ran at build step
├── config // adjustable settings and options
├── context // shared contexts used by different portions of the app
├── dictionaries // Language-specific translate files
├── docker // container configurations
├── hooks // Reusable hooks
├── i18n // Internationalization configuration
├── models // describes data models & shapes of API responses
├── public // meta assets like favicon
├── service // specifies shapes of API actions
├── test
├── types // descriptions of function params and return values
└── utils // Shared utility functions
```

## 提交你的 PR

最后,是时候向我们的仓库提交一个拉取请求(PR)了。对于重要的功能,我们首先将它们合并到 `deploy/dev` 分支进行测试,然后再合并到 `main` 分支。如果你遇到合并冲突或者不知道如何提交拉取请求的问题,请查看 [GitHub 的拉取请求教程](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests)。

就是这样!一旦你的 PR 被合并,你将成为我们 [README](https://github.com/langgenius/dify/blob/main/README.md) 中的贡献者。

## 获取帮助

如果你在贡献过程中遇到困难或者有任何问题,可以通过相关的 GitHub 问题提出你的疑问,或者加入我们的 [Discord](https://discord.gg/AhzKf7dNgk) 进行快速交流。
@@ -1,55 +0,0 @@
# コントリビュート

[Dify](https://dify.ai) に興味を持ち、貢献したいと思うようになったことに感謝します!始める前に、
[行動規範](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)を読み、
[既存の問題](https://github.com/langgenius/langgenius-gateway/issues)をチェックしてください。
本ドキュメントは、[Dify](https://dify.ai) をビルドしてテストするための開発環境の構築方法を説明するものです。

### 依存関係のインストール

[Dify](https://dify.ai)をビルドするには、お使いのマシンに以下の依存関係をインストールし、設定する必要があります:

- [Git](http://git-scm.com/)
- [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) バージョン 8.x.x もしくは [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) バージョン 3.10.x

## ローカル開発

開発環境を構築するには、プロジェクトの git リポジトリをフォークし、適切なパッケージマネージャを使用してバックエンドとフロントエンドの依存関係をインストールし、docker-compose スタックを実行するように作成します。

### リポジトリのフォーク

[リポジトリ](https://github.com/langgenius/dify) をフォークする必要があります。

### リポジトリのクローン

GitHub でフォークしたリポジトリのクローンを作成する:

```
git clone git@github.com:<github_username>/dify.git
```

### バックエンドのインストール

バックエンドアプリケーションのインストール方法については、[Backend README](api/README.md) を参照してください。

### フロントエンドのインストール

フロントエンドアプリケーションのインストール方法については、[Frontend README](web/README.md) を参照してください。

### ブラウザで dify にアクセス

[Dify](https://dify.ai) をローカル環境で見ることができるようになりました [http://localhost:3000](http://localhost:3000)。

## プルリクエストの作成

変更後、プルリクエスト (PR) をオープンしてください。プルリクエストを提出すると、Dify チーム/コミュニティの他の人があなたと一緒にそれをレビューします。

マージコンフリクトなどの問題が発生したり、プルリクエストの開き方がわからなくなったりしませんでしたか? [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests) で、マージコンフリクトやその他の問題を解決する方法をチェックしてみてください。あなたの PR がマージされると、[コントリビュータチャート](https://github.com/langgenius/langgenius-gateway/graphs/contributors)にコントリビュータとして誇らしげに掲載されます。

## コミュニティチャンネル

お困りですか?何か質問がありますか? [Discord Community サーバ](https://discord.gg/j3XRWSPBf7) に参加してください。私たちがお手伝いします!
@@ -21,6 +21,11 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
</a>
</p>

**Dify** is an LLM application development platform that has helped build over **100,000** applications. It integrates BaaS and LLMOps, covering the essential tech stack for building generative AI-native applications, including a built-in RAG engine. Dify allows you to **deploy your own version of Assistants API and GPTs, based on any LLMs.**

@@ -55,7 +60,8 @@ You can try out [Dify.AI Cloud](https://dify.ai) now. It provides all the capabi

**3. RAG Engine**: Includes various RAG capabilities based on full-text indexing or vector database embeddings, allowing direct upload of PDFs, TXTs, and other text formats.

**4. Agents**: A Function Calling based Agent framework that allows users to configure what they see is what they get. Dify includes basic plugin capabilities like Google Search.
**4. AI Agent**: Based on Function Calling and ReAct, the Agent inference framework allows users to customize tools, what you see is what you get. Dify provides more than a dozen built-in tool calling capabilities, such as Google Search, DALL·E, Stable Diffusion, WolframAlpha, etc.

**5. Continuous Operations**: Monitor and analyze application logs and performance, continuously improving Prompts, datasets, or models using production data.
@@ -21,6 +21,12 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
<a href="https://mp.weixin.qq.com/s/TnyfIuH-tPi9o1KNjwVArw" target="_blank">
Dify 发布 AI Agent 能力:基于不同的大型语言模型构建 GPTs 和 Assistants
</a>
</p>

Dify 是一个 LLM 应用开发平台,已经有超过 10 万个应用基于 Dify.AI 构建。它融合了 Backend as Service 和 LLMOps 的理念,涵盖了构建生成式 AI 原生应用所需的核心技术栈,包括一个内置 RAG 引擎。使用 Dify,你可以基于任何模型自部署类似 Assistants API 和 GPTs 的能力。



@@ -53,7 +59,7 @@ Dify 具有模型中立性,相较 LangChain 等硬编码开发库 Dify 是一

**3. RAG引擎**:包括各种基于全文索引或向量数据库嵌入的 RAG 能力,允许直接上传 PDF、TXT 等各种文本格式。

**4. Agent**:基于函数调用的 Agent框架,允许用户自定义配置,所见即所得。Dify 提供了基本的插件能力,如谷歌搜索。
**4. AI Agent**:基于 Function Calling 和 ReAct 的 Agent 推理框架,允许用户自定义工具,所见即所得。Dify 提供了十多种内置工具调用能力,如谷歌搜索、DALL·E、Stable Diffusion、WolframAlpha 等。

**5. 持续运营**:监控和分析应用日志和性能,使用生产数据持续改进 Prompt、数据集或模型。
@@ -21,6 +21,12 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
</a>
</p>

**Dify** es una plataforma de desarrollo de aplicaciones para modelos de lenguaje de gran tamaño (LLM) que ya ha visto la creación de más de **100,000** aplicaciones basadas en Dify.AI. Integra los conceptos de Backend como Servicio y LLMOps, cubriendo el conjunto de tecnologías esenciales requerido para construir aplicaciones nativas de inteligencia artificial generativa, incluyendo un motor RAG incorporado. Con Dify, **puedes auto-desplegar capacidades similares a las de Assistants API y GPTs basadas en cualquier LLM.**



@@ -52,7 +58,7 @@ Dify se caracteriza por su neutralidad de modelo y es un conjunto tecnológico c

**3. Motor RAG**: Incluye varias capacidades RAG basadas en indexación de texto completo o incrustaciones de base de datos vectoriales, permitiendo la carga directa de PDFs, TXTs y otros formatos de texto.

**4. Agentes**: Un marco de Agentes basado en Llamadas de Función que permite a los usuarios configurar lo que ven es lo que obtienen. Dify incluye capacidades básicas de plugins como la Búsqueda de Google.
**4. Agente de IA**: Basado en la llamada de funciones y ReAct, el marco de inferencia del Agente permite a los usuarios personalizar las herramientas, lo que ves es lo que obtienes. Dify proporciona más de una docena de capacidades de llamada de herramientas incorporadas, como Búsqueda de Google, DALL·E, Difusión Estable, WolframAlpha, etc.

**5. Operaciones Continuas**: Monitorear y analizar registros de aplicaciones y rendimiento, mejorando continuamente Prompts, conjuntos de datos o modelos usando datos de producción.
@@ -21,6 +21,13 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
</a>
</p>

**Dify** est une plateforme de développement d'applications LLM qui a déjà vu plus de **100,000** applications construites sur Dify.AI. Elle intègre les concepts de Backend as a Service et LLMOps, couvrant la pile technologique de base requise pour construire des applications natives d'IA générative, y compris un moteur RAG intégré. Avec Dify, **vous pouvez auto-déployer des capacités similaires aux API Assistants et GPT basées sur n'importe quels LLM.**



@@ -52,7 +59,7 @@ Dify présente une neutralité de modèle et est une pile technologique complèt

**3\. Moteur RAG**: Comprend diverses capacités RAG basées sur l'indexation de texte intégral ou les embeddings de base de données vectorielles, permettant le chargement direct de PDF, TXT et autres formats de texte.

**4\. Agents**: Un framework d'agents basé sur l'appel de fonctions qui permet aux utilisateurs de configurer ce qu'ils voient est ce qu'ils obtiennent. Dify comprend des capacités de plug-in de base comme Google Search.
**4\. AI Agent**: Basé sur l'appel de fonction et ReAct, le framework d'inférence de l'Agent permet aux utilisateurs de personnaliser les outils, ce que vous voyez est ce que vous obtenez. Dify propose plus d'une douzaine de capacités d'appel d'outils intégrées, telles que la recherche Google, DALL·E, Diffusion Stable, WolframAlpha, etc.

**5\. Opérations continues**: Surveillez et analysez les journaux et les performances des applications, améliorez en continu les invites, les datasets ou les modèles à l'aide de données de production.
@@ -21,6 +21,13 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>

<p align="center">
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
</a>
</p>

"Difyは、既にDify.AI上で10万以上のアプリケーションが構築されているLLMアプリケーション開発プラットフォームです。バックエンド・アズ・ア・サービスとLLMOpsの概念を統合し、組み込みのRAGエンジンを含む、生成AIネイティブアプリケーションを構築するためのコアテックスタックをカバーしています。Difyを使用すると、どのLLMに基づいても、Assistants APIやGPTのような機能を自己デプロイすることができます。"

Please note that translating complex technical terms can sometimes result in slight variations in meaning due to differences in language nuances.

@@ -54,7 +61,7 @@ Difyはモデルニュートラルであり、LangChainのようなハードコ

**3\. RAGエンジン**: フルテキストインデックスまたはベクトルデータベース埋め込みに基づくさまざまなRAG機能を含み、PDF、TXT、その他のテキストフォーマットの直接アップロードを可能にします。

**4\. エージェント**: ユーザーが sees what they get を設定できる関数呼び出しベースのエージェントフレームワーク。 Difyには、Google検索などの基本的なプラグイン機能が含まれています。
**4. AIエージェント**: 関数呼び出しとReActに基づくAgent推論フレームワークにより、ユーザーはツールをカスタマイズすることができます。Difyは、Google検索、DALL·E、Stable Diffusion、WolframAlphaなど、十数種類の組み込みツール呼び出し機能を提供しています。

**5\. 継続的運用**: アプリケーションログとパフォーマンスを監視および分析し、運用データを使用してプロンプト、データセット、またはモデルを継続的に改善します。
@@ -52,7 +52,7 @@ Dify Daq rIn neutrality 'ej Hoch, LangChain tInHar HubwI'. maH Daqbe'law' Qawqar
**3. RAG Engine**: RAG vaD tIqpu' lo'taH indexing qor neH vector database wa' embeddings wIj, PDFs, TXTs, 'ej ghojmoHmoH HIq qorlIj je upload.

**4. jenSuvpu'**: jenbe' SuDqang naQ moDwu' jenSuvpu' porgh cha'logh choHvam. Dify Google Search Hur vItlhutlh plugin choH.
**4. AI Agent**: Function Calling 'ej ReAct Daq Hurmey, Agent inference framework Hoch users customize tools, vaj 'oH QaQ. Dify Hoch loS ghaH 'ej wa'vatlh built-in tool calling capabilities, Google Search, DALL·E, Stable Diffusion, WolframAlpha, 'ej.

**5. QaS muDHa'wI': cha'logh wa' pIq mI' logs 'ej quv yIn, vItlhutlh tIq 'e'wIj lo'taHmoHmoH Prompts, vItlhutlh, Hurmey ghaH production data jatlh.
@@ -1,18 +1,20 @@
|
||||
# packages install stage
|
||||
FROM python:3.10-slim AS base
|
||||
# base image
|
||||
FROM python:3.10-slim-bookworm AS base
|
||||
|
||||
LABEL maintainer="takatost@gmail.com"
|
||||
|
||||
# install packages
|
||||
FROM base as packages
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gcc g++ python3-dev libc-dev libffi-dev
|
||||
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
||||
COPY requirements.txt /requirements.txt
|
||||
|
||||
RUN pip install --prefix=/pkg -r requirements.txt
|
||||
|
||||
# build stage
|
||||
FROM python:3.10-slim AS builder
|
||||
|
||||
# production stage
|
||||
FROM base AS production
|
||||
|
||||
ENV FLASK_APP app.py
|
||||
ENV EDITION SELF_HOSTED
|
||||
@@ -30,11 +32,11 @@ ENV TZ UTC
|
||||
WORKDIR /app/api
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends bash curl wget vim nodejs ffmpeg \
|
||||
&& apt-get install -y --no-install-recommends curl wget vim nodejs ffmpeg libgmp-dev libmpfr-dev libmpc-dev \
|
||||
&& apt-get autoremove \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=base /pkg /usr/local
|
||||
COPY --from=packages /pkg /usr/local
|
||||
COPY . /app/api/
|
||||
|
||||
COPY docker/entrypoint.sh /entrypoint.sh
|
||||
|
||||
@@ -93,7 +93,7 @@ class Config:
|
||||
# ------------------------
|
||||
# General Configurations.
|
||||
# ------------------------
|
||||
self.CURRENT_VERSION = "0.5.0"
|
||||
self.CURRENT_VERSION = "0.5.1"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
|
||||
@@ -11,10 +11,13 @@ from .app import (advanced_prompt_template, annotation, app, audio, completion,
|
||||
model_config, site, statistic)
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_oauth, login, oauth
|
||||
from .billing import billing
|
||||
# Import datasets controllers
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
|
||||
# Import explore controllers
|
||||
from .explore import audio, completion, conversation, installed_app, message, parameter, recommended_app, saved_message
|
||||
# Import workspace controllers
|
||||
from .workspace import account, members, model_providers, models, tool_providers, workspace
|
||||
# Import billing controllers
|
||||
from .billing import billing
|
||||
# Import operation controllers
|
||||
from .operation import operation
|
||||
|
||||
@@ -34,8 +34,7 @@ class ChatMessageAudioApi(Resource):
|
||||
try:
|
||||
response = AudioService.transcript_asr(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
promot=app_model.app_model_config.pre_prompt
|
||||
file=file
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -32,6 +32,7 @@ class ChatAudioApi(InstalledAppResource):
|
||||
response = AudioService.transcript_asr(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
end_user=None
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
0 api/controllers/console/operation/__init__.py (new file)
30 api/controllers/console/operation/operation.py (new file)
@@ -0,0 +1,30 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud
|
||||
from libs.login import login_required
|
||||
from services.operation_service import OperationService
|
||||
|
||||
|
||||
class TenantUtm(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('utm_source', type=str, required=True)
|
||||
parser.add_argument('utm_medium', type=str, required=True)
|
||||
parser.add_argument('utm_campaign', type=str, required=False, default='')
|
||||
parser.add_argument('utm_content', type=str, required=False, default='')
|
||||
parser.add_argument('utm_term', type=str, required=False, default='')
|
||||
args = parser.parse_args()
|
||||
|
||||
return OperationService.record_utm(current_user.current_tenant_id, args)
|
||||
|
||||
|
||||
api.add_resource(TenantUtm, '/operation/utm')
|
||||
@@ -66,6 +66,7 @@ class TextApi(AppApiResource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@@ -73,7 +74,7 @@ class TextApi(AppApiResource):
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=args['text'],
|
||||
end_user=args['user'],
|
||||
streaming=False
|
||||
streaming=args['streaming']
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -44,6 +44,7 @@ class MessageListApi(AppApiResource):
|
||||
'position': fields.Integer,
|
||||
'thought': fields.String,
|
||||
'tool': fields.String,
|
||||
'tool_labels': fields.Raw,
|
||||
'tool_input': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'observation': fields.String,
|
||||
|
||||
@@ -75,7 +75,7 @@ def validate_dataset_token(view=None):
|
||||
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
|
||||
.filter(Tenant.id == api_token.tenant_id) \
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
||||
.filter(TenantAccountJoin.role.in_(['owner', 'admin'])) \
|
||||
.filter(TenantAccountJoin.role.in_(['owner'])) \
|
||||
.one_or_none()
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
|
||||
@@ -31,6 +31,7 @@ class AudioApi(WebApiResource):
|
||||
response = AudioService.transcript_asr(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file,
|
||||
end_user=end_user
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -18,6 +18,7 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
@@ -281,7 +282,7 @@ class GenerateTaskPipeline:
|
||||
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, QueueAgentThoughtEvent):
|
||||
agent_thought = (
|
||||
agent_thought: MessageAgentThought = (
|
||||
db.session.query(MessageAgentThought)
|
||||
.filter(MessageAgentThought.id == event.agent_thought_id)
|
||||
.first()
|
||||
@@ -298,6 +299,7 @@ class GenerateTaskPipeline:
|
||||
'thought': agent_thought.thought,
|
||||
'observation': agent_thought.observation,
|
||||
'tool': agent_thought.tool,
|
||||
'tool_labels': agent_thought.tool_labels,
|
||||
'tool_input': agent_thought.tool_input,
|
||||
'created_at': int(self._message.created_at.timestamp()),
|
||||
'message_files': agent_thought.files
|
||||
|
||||
@@ -153,8 +153,16 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
if provider_record:
|
||||
try:
|
||||
original_credentials = json.loads(
|
||||
provider_record.encrypted_config) if provider_record.encrypted_config else {}
|
||||
# fix origin data
|
||||
if provider_record.encrypted_config:
|
||||
if not provider_record.encrypted_config.startswith("{"):
|
||||
original_credentials = {
|
||||
"openai_api_key": provider_record.encrypted_config
|
||||
}
|
||||
else:
|
||||
original_credentials = json.loads(provider_record.encrypted_config)
|
||||
else:
|
||||
original_credentials = {}
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
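
The hunk above exists because legacy `provider_record.encrypted_config` rows stored a bare OpenAI API key rather than a JSON object. A minimal standalone sketch of the same parse-or-wrap logic (the helper name is illustrative, not part of Dify's code):

```python
import json


def load_original_credentials(encrypted_config: str | None) -> dict:
    """Illustrative helper: mirror the parse-or-wrap handling in the hunk above."""
    if not encrypted_config:
        return {}
    # Legacy rows held a raw API key; anything that doesn't look like JSON is wrapped.
    if not encrypted_config.startswith("{"):
        return {"openai_api_key": encrypted_config}
    try:
        return json.loads(encrypted_config)
    except json.JSONDecodeError:
        return {}


# A legacy record holding a bare key vs. a current JSON record.
assert load_original_credentials("sk-legacy-key") == {"openai_api_key": "sk-legacy-key"}
assert load_original_credentials('{"openai_api_key": "sk-new"}') == {"openai_api_key": "sk-new"}
```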
|
||||
|
||||
|
||||
@@ -396,6 +396,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
message_chain_id=None,
|
||||
thought='',
|
||||
tool=tool_name,
|
||||
tool_labels_str='{}',
|
||||
tool_input=tool_input,
|
||||
message=message,
|
||||
message_token=0,
|
||||
@@ -469,6 +470,21 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
agent_thought.tokens = llm_usage.total_tokens
|
||||
agent_thought.total_price = llm_usage.total_price
|
||||
|
||||
# check if tool labels is not empty
|
||||
labels = agent_thought.tool_labels or {}
|
||||
tools = agent_thought.tool.split(';') if agent_thought.tool else []
|
||||
for tool in tools:
|
||||
if not tool:
|
||||
continue
|
||||
if tool not in labels:
|
||||
tool_label = ToolManager.get_tool_label(tool)
|
||||
if tool_label:
|
||||
labels[tool] = tool_label.to_dict()
|
||||
else:
|
||||
labels[tool] = {'en_US': tool, 'zh_Hans': tool}
|
||||
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def get_history_prompt_messages(self) -> List[PromptMessage]:
|
||||
|
||||
@@ -298,7 +298,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
),
|
||||
usage=llm_usage['usage'],
|
||||
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
|
||||
system_fingerprint=''
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
|
||||
@@ -276,7 +276,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer,
|
||||
),
|
||||
usage=llm_usage['usage'],
|
||||
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
|
||||
system_fingerprint=''
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
|
||||
@@ -655,7 +655,9 @@ class IndexingRunner:
|
||||
else:
|
||||
page_content = page_content
|
||||
document_node.page_content = page_content
|
||||
split_documents.append(document_node)
|
||||
|
||||
if document_node.page_content:
|
||||
split_documents.append(document_node)
|
||||
all_documents.extend(split_documents)
|
||||
# processing qa document
|
||||
if document_form == 'qa_model':
|
||||
|
||||
@@ -13,6 +13,7 @@ This module provides the interface for invoking and authenticating various model
|
||||
- `Text Embedding Model` - Text Embedding, pre-computed tokens capability
|
||||
- `Rerank Model` - Segment Rerank capability
|
||||
- `Speech-to-text Model` - Speech to text capability
|
||||
- `Text-to-speech Model` - Text to speech capability
|
||||
- `Moderation` - Moderation capability
|
||||
|
||||
- Model provider display
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
- `Text Embedding Model` - 文本 Embedding,预计算 tokens 能力
|
||||
- `Rerank Model` - 分段 Rerank 能力
|
||||
- `Speech-to-text Model` - 语音转文本能力
|
||||
- `Text-to-speech Model` - 文本转语音能力
|
||||
- `Moderation` - Moderation 能力
|
||||
|
||||
- 模型供应商展示
|
||||
|
||||
@@ -299,9 +299,7 @@ Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement
|
||||
- Invoke Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@@ -331,6 +329,46 @@ Inherit the `__base.speech2text_model.Speech2TextModel` base class and implement
|
||||
|
||||
The string after speech-to-text conversion.
|
||||
|
||||
### Text2speech
|
||||
|
||||
Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement the following interfaces:
|
||||
|
||||
- Invoke Invocation
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: translated audio file
|
||||
"""
|
||||
```
|
||||
|
||||
- Parameters:
|
||||
|
||||
- `model` (string) Model name
|
||||
|
||||
- `credentials` (object) Credential information
|
||||
|
||||
The parameters of credential information are defined by either the `provider_credential_schema` or `model_credential_schema` in the provider's YAML configuration file. Inputs such as `api_key` are included.
|
||||
|
||||
- `content_text` (string) The text content that needs to be converted
|
||||
|
||||
- `streaming` (bool) Whether to stream output
|
||||
|
||||
- `user` (string) [optional] Unique identifier of the user
|
||||
|
||||
This can help the provider monitor and detect abusive behavior.
|
||||
|
||||
- Returns:
|
||||
|
||||
The speech stream converted from the text.
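
For orientation, a minimal sketch of what a provider implementation of this interface could look like. It assumes the Dify `TTSModel` base class is importable and uses a placeholder endpoint, credential key, and payload, so treat it as an illustration of the signature and return shape rather than the actual OpenAI implementation that appears later in this diff.

```python
from typing import Optional

import requests  # assumed available in the provider runtime

from core.model_runtime.model_providers.__base.tts_model import TTSModel  # Dify base class


class ExampleText2SpeechModel(TTSModel):
    """Hypothetical provider, used only to illustrate the interface shape."""

    def _invoke(self, model: str, credentials: dict, content_text: str,
                streaming: bool, user: Optional[str] = None):
        # Call a placeholder TTS HTTP endpoint; URL and payload are not a real provider API.
        response = requests.post(
            "https://tts.example.com/v1/audio/speech",
            headers={"Authorization": f"Bearer {credentials['api_key']}"},
            json={"model": model, "input": content_text, "user": user},
            stream=streaming,
            timeout=(10, 300),
        )
        response.raise_for_status()
        if streaming:
            # Stream the audio back chunk by chunk.
            return response.iter_content(chunk_size=4096)
        return response.content  # raw audio bytes
```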
|
||||
|
||||
### Moderation
|
||||
|
||||
Inherit the `__base.moderation_model.ModerationModel` base class and implement the following interfaces:
|
||||
|
||||
@@ -94,6 +94,7 @@ The currently supported model types are as follows:
|
||||
- `text_embedding` Text Embedding model
|
||||
- `rerank` Rerank model
|
||||
- `speech2text` Speech to text
|
||||
- `tts` Text to speech
|
||||
- `moderation` Moderation
|
||||
|
||||
Continuing with `Anthropic` as an example, since `Anthropic` only supports LLM, we create a `module` named `llm` in `model_providers.anthropic`.
|
||||
|
||||
@@ -47,6 +47,10 @@
|
||||
- `max_chunks` (int) Maximum number of chunks (available for model types `text-embedding`, `moderation`)
|
||||
- `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`)
|
||||
- `supported_file_extensions` (string) Supported file extension formats, e.g., mp3, mp4 (available for model type `speech2text`)
|
||||
- `default_voice` (string) default voice, e.g.:alloy,echo,fable,onyx,nova,shimmer(available for model type `tts`)
|
||||
- `word_limit` (int) Single conversion word limit, paragraphwise by default(available for model type `tts`)
|
||||
- `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`)
|
||||
- `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`)
|
||||
- `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
|
||||
- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] Model invocation parameter rules
|
||||
- `pricing` ([PriceConfig](#PriceConfig)) [optional] Pricing information
|
||||
@@ -58,6 +62,7 @@
|
||||
- `text-embedding` Text Embedding model
|
||||
- `rerank` Rerank model
|
||||
- `speech2text` Speech to text
|
||||
- `tts` Text to speech
|
||||
- `moderation` Moderation
|
||||
|
||||
### ConfigurateMethod
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
- `text_embedding` 文本 Embedding 模型
|
||||
- `rerank` Rerank 模型
|
||||
- `speech2text` 语音转文字
|
||||
- `tts` 文字转语音
|
||||
- `moderation` 审查
|
||||
|
||||
`Xinference`支持`LLM`和`Text Embedding`和Rerank,那么我们开始编写`xinference.yaml`。
|
||||
|
||||
@@ -369,6 +369,46 @@ class XinferenceProvider(Provider):
|
||||
|
||||
语音转换后的字符串。
|
||||
|
||||
### Text2speech
|
||||
|
||||
继承 `__base.text2speech_model.Text2SpeechModel` 基类,实现以下接口:
|
||||
|
||||
- Invoke 调用
|
||||
|
||||
```python
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: translated audio file
|
||||
"""
|
||||
```
|
||||
|
||||
- 参数:
|
||||
|
||||
- `model` (string) 模型名称
|
||||
|
||||
- `credentials` (object) 凭据信息
|
||||
|
||||
凭据信息的参数由供应商 YAML 配置文件的 `provider_credential_schema` 或 `model_credential_schema` 定义,传入如:`api_key` 等。
|
||||
|
||||
- `content_text` (string) 需要转换的文本内容
|
||||
|
||||
- `streaming` (bool) 是否进行流式输出
|
||||
|
||||
- `user` (string) [optional] 用户的唯一标识符
|
||||
|
||||
可以帮助供应商监控和检测滥用行为。
|
||||
|
||||
- 返回:
|
||||
|
||||
文本转换后的语音流。
|
||||
|
||||
### Moderation
|
||||
|
||||
继承 `__base.moderation_model.ModerationModel` 基类,实现以下接口:
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
- `text_embedding` 文本 Embedding 模型
|
||||
- `rerank` Rerank 模型
|
||||
- `speech2text` 语音转文字
|
||||
- `tts` 文字转语音
|
||||
- `moderation` 审查
|
||||
|
||||
依旧以 `Anthropic` 为例,`Anthropic` 仅支持 LLM,因此在 `model_providers.anthropic` 创建一个 `llm` 为名称的 `module`。
|
||||
|
||||
@@ -48,6 +48,10 @@
|
||||
- `max_chunks` (int) 最大分块数量 (模型类型 `text-embedding ` `moderation` 可用)
|
||||
- `file_upload_limit` (int) 文件最大上传限制,单位:MB。(模型类型 `speech2text` 可用)
|
||||
- `supported_file_extensions` (string) 支持文件扩展格式,如:mp3,mp4(模型类型 `speech2text` 可用)
|
||||
- `default_voice` (string) 缺省音色,可选:alloy,echo,fable,onyx,nova,shimmer(模型类型 `tts` 可用)
|
||||
- `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用)
|
||||
- `audio_type` (string) 支持音频文件扩展格式,如:mp3,wav(模型类型 `tts` 可用)
|
||||
- `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用)
|
||||
- `max_characters_per_chunk` (int) 每块最大字符数 (模型类型 `moderation` 可用)
|
||||
- `parameter_rules` (array[[ParameterRule](#ParameterRule)]) [optional] 模型调用参数规则
|
||||
- `pricing` ([PriceConfig](#PriceConfig)) [optional] 价格信息
|
||||
@@ -59,6 +63,7 @@
|
||||
- `text-embedding` 文本 Embedding 模型
|
||||
- `rerank` Rerank 模型
|
||||
- `speech2text` 语音转文字
|
||||
- `tts` 文字转语音
|
||||
- `moderation` 审查
|
||||
|
||||
### ConfigurateMethod
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import uuid
|
||||
import hashlib
|
||||
import subprocess
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
|
||||
|
||||
class TTSModel(AIModel):
|
||||
@@ -40,3 +45,96 @@ class TTSModel(AIModel):
|
||||
:return: translated audio file
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_model_voice(self, model: str, credentials: dict) -> any:
|
||||
"""
|
||||
Get voice for given tts model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: voice
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.DEFAULT_VOICE]
|
||||
|
||||
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
|
||||
"""
|
||||
Get audio type for given tts model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: voice
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.AUDOI_TYPE in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.AUDOI_TYPE]
|
||||
|
||||
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get audio type for given tts model
|
||||
:return: audio type
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
|
||||
|
||||
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get audio max workers for given tts model
|
||||
:return: audio type
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
||||
|
||||
@staticmethod
|
||||
def _split_text_into_sentences(text: str, limit: int, delimiters=None):
|
||||
if delimiters is None:
|
||||
delimiters = set('。!?;\n')
|
||||
|
||||
buf = []
|
||||
word_count = 0
|
||||
for char in text:
|
||||
buf.append(char)
|
||||
if char in delimiters:
|
||||
if word_count >= limit:
|
||||
yield ''.join(buf)
|
||||
buf = []
|
||||
word_count = 0
|
||||
else:
|
||||
word_count += 1
|
||||
else:
|
||||
word_count += 1
|
||||
|
||||
if buf:
|
||||
yield ''.join(buf)
|
||||
|
||||
@staticmethod
|
||||
def _is_ffmpeg_installed():
|
||||
try:
|
||||
output = subprocess.check_output("ffmpeg -version", shell=True)
|
||||
if "ffmpeg version" in output.decode("utf-8"):
|
||||
return True
|
||||
else:
|
||||
raise InvokeBadRequestError("ffmpeg is not installed, "
|
||||
"details: https://docs.dify.ai/getting-started/install-self-hosted"
|
||||
"/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech")
|
||||
except Exception:
|
||||
raise InvokeBadRequestError("ffmpeg is not installed, "
|
||||
"details: https://docs.dify.ai/getting-started/install-self-hosted"
|
||||
"/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech")
|
||||
|
||||
# Todo: To improve the streaming function
|
||||
@staticmethod
|
||||
def _get_file_name(file_content: str) -> str:
|
||||
hash_object = hashlib.sha256(file_content.encode())
|
||||
hex_digest = hash_object.hexdigest()
|
||||
|
||||
namespace_uuid = uuid.UUID('a5da6ef9-b303-596f-8e88-bf8fa40f4b31')
|
||||
unique_uuid = uuid.uuid5(namespace_uuid, hex_digest)
|
||||
return str(unique_uuid)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import copy
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
@@ -76,7 +76,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=[""],
|
||||
texts="",
|
||||
extra_model_kwargs=extra_model_kwargs
|
||||
)
|
||||
|
||||
@@ -147,7 +147,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
|
||||
return ai_model_entity.entity
|
||||
|
||||
@staticmethod
|
||||
def _embedding_invoke(model: str, client: AzureOpenAI, texts: list[str],
|
||||
def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str],
|
||||
extra_model_kwargs: dict) -> Tuple[list[list[float]], int]:
|
||||
response = client.embeddings.create(
|
||||
input=texts,
|
||||
|
||||
@@ -76,7 +76,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=[""]
|
||||
texts=[" "]
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
@@ -131,6 +131,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
:param text: text to tokenize
|
||||
:return:
|
||||
"""
|
||||
if not text:
|
||||
return Tokens([], [], {})
|
||||
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ parameter_rules:
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
pricing:
|
||||
input: '0.00'
|
||||
input: '0.015'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
|
||||
@@ -36,7 +36,7 @@ parameter_rules:
|
||||
en_US: Enable Web Search
|
||||
zh_Hans: 开启网页搜索
|
||||
pricing:
|
||||
input: '0.00'
|
||||
input: '0.015'
|
||||
output: '0.015'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
|
||||
@@ -29,7 +29,7 @@ parameter_rules:
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
pricing:
|
||||
input: '0.00'
|
||||
input: '0.005'
|
||||
output: '0.005'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
model: abab6-chat
|
||||
label:
|
||||
en_US: Abab6-Chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.01
|
||||
max: 1
|
||||
default: 0.1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0.01
|
||||
max: 1
|
||||
default: 0.9
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
min: 1
|
||||
max: 32768
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
pricing:
|
||||
input: '0.1'
|
||||
output: '0.1'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
||||
@@ -78,7 +78,7 @@ class MinimaxChatCompletion(object):
|
||||
|
||||
try:
|
||||
response = post(
|
||||
url=url, data=dumps(body), headers=headers, stream=stream, timeout=10)
|
||||
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
|
||||
except Exception as e:
|
||||
raise InternalServerError(e)
|
||||
|
||||
|
||||
@@ -22,9 +22,6 @@ class MinimaxChatCompletionPro(object):
|
||||
"""
|
||||
generate chat completion
|
||||
"""
|
||||
if model not in ['abab5.5-chat', 'abab5.5s-chat']:
|
||||
raise BadRequestError(f'Invalid model: {model}')
|
||||
|
||||
if not api_key or not group_id:
|
||||
raise InvalidAPIKeyError('Invalid API key or group ID')
|
||||
|
||||
@@ -87,7 +84,7 @@ class MinimaxChatCompletionPro(object):
|
||||
|
||||
try:
|
||||
response = post(
|
||||
url=url, data=dumps(body), headers=headers, stream=stream, timeout=10)
|
||||
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
|
||||
except Exception as e:
|
||||
raise InternalServerError(e)
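
The timeout change in both hunks above follows the `requests` convention that a tuple sets separate connect and read timeouts, so a streamed completion can run for minutes while connection failures still surface within seconds. A hedged illustration (the URL is a placeholder, not the real MiniMax endpoint):

```python
import requests

# (connect timeout, read timeout): fail fast on connect, allow long streamed reads.
response = requests.post(
    "https://minimax.example.com/v1/text/chatcompletion",  # placeholder URL
    json={"stream": True},
    stream=True,
    timeout=(10, 300),
)
```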
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from typing import Generator, List, Optional, Union
|
||||
from typing import Generator, List
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
|
||||
SystemPromptMessage, UserPromptMessage)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, ParameterRule, ParameterType
|
||||
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
|
||||
InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
@@ -18,6 +17,7 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
|
||||
|
||||
class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
model_apis = {
|
||||
'abab6-chat': MinimaxChatCompletionPro,
|
||||
'abab5.5s-chat': MinimaxChatCompletionPro,
|
||||
'abab5.5-chat': MinimaxChatCompletionPro,
|
||||
'abab5-chat': MinimaxChatCompletion
|
||||
@@ -55,7 +55,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
||||
stream=False,
|
||||
user=''
|
||||
)
|
||||
except InvalidAuthenticationError as e:
|
||||
except (InvalidAuthenticationError, InsufficientAccountBalanceError) as e:
|
||||
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
|
||||
@@ -27,4 +27,4 @@ class MinimaxProvider(ModelProvider):
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
raise CredentialsValidateFailedError(f'{ex}')
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import base64
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
@@ -89,7 +89,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=[""],
|
||||
texts="",
|
||||
extra_model_kwargs=extra_model_kwargs
|
||||
)
|
||||
|
||||
@@ -160,7 +160,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _embedding_invoke(self, model: str, client: OpenAI, texts: list[str],
|
||||
def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str],
|
||||
extra_model_kwargs: dict) -> Tuple[list[list[float]], int]:
|
||||
"""
|
||||
Invoke embedding model
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
import uuid
|
||||
import hashlib
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from functools import reduce
|
||||
from pydub import AudioSegment
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.openai._common import _CommonOpenAI
|
||||
|
||||
from typing_extensions import Literal
|
||||
from flask import Response, stream_with_context
|
||||
from openai import OpenAI
|
||||
import concurrent.futures
|
||||
@@ -22,9 +17,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
"""
|
||||
Model class for OpenAI Speech to text model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool,
|
||||
user: Optional[str] = None) -> any:
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_invoke text2speech model
|
||||
|
||||
@@ -65,7 +58,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _tts_invoke(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> any:
|
||||
def _tts_invoke(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> Response:
|
||||
"""
|
||||
_tts_invoke text2speech model
|
||||
|
||||
@@ -104,8 +97,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
# Todo: To improve the streaming function
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str,
|
||||
user: Optional[str] = None) -> any:
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_tts_invoke_streaming text2speech model
|
||||
|
||||
@@ -131,84 +123,6 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
def _get_model_voice(self, model: str, credentials: dict) -> Literal[
|
||||
"alloy", "echo", "fable", "onyx", "nova", "shimmer"]:
|
||||
"""
|
||||
Get voice for given tts model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: voice
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.DEFAULT_VOICE]
|
||||
|
||||
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
|
||||
"""
|
||||
Get audio type for given tts model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: voice
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.AUDOI_TYPE in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.AUDOI_TYPE]
|
||||
|
||||
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get audio type for given tts model
|
||||
:return: audio type
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
|
||||
|
||||
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get audio max workers for given tts model
|
||||
:return: audio type
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
|
||||
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
||||
|
||||
@staticmethod
|
||||
def _split_text_into_sentences(text: str, limit: int, delimiters=None):
|
||||
if delimiters is None:
|
||||
delimiters = set('。!?;\n')
|
||||
|
||||
buf = []
|
||||
word_count = 0
|
||||
for char in text:
|
||||
buf.append(char)
|
||||
if char in delimiters:
|
||||
if word_count >= limit:
|
||||
yield ''.join(buf)
|
||||
buf = []
|
||||
word_count = 0
|
||||
else:
|
||||
word_count += 1
|
||||
else:
|
||||
word_count += 1
|
||||
|
||||
if buf:
|
||||
yield ''.join(buf)
|
||||
|
||||
@staticmethod
|
||||
def _get_file_name(file_content: str) -> str:
|
||||
hash_object = hashlib.sha256(file_content.encode())
|
||||
hex_digest = hash_object.hexdigest()
|
||||
|
||||
namespace_uuid = uuid.UUID('a5da6ef9-b303-596f-8e88-bf8fa40f4b31')
|
||||
unique_uuid = uuid.uuid5(namespace_uuid, hex_digest)
|
||||
return str(unique_uuid)
|
||||
|
||||
def _process_sentence(self, sentence: str, model: str, credentials: dict):
|
||||
"""
|
||||
_tts_invoke openai text2speech model api
|
||||
@@ -226,18 +140,3 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
|
||||
response = client.audio.speech.create(model=model, voice=voice_name, input=sentence.strip())
|
||||
if isinstance(response.read(), bytes):
|
||||
return response.read()
|
||||
|
||||
@staticmethod
|
||||
def _is_ffmpeg_installed():
|
||||
try:
|
||||
output = subprocess.check_output("ffmpeg -version", shell=True)
|
||||
if "ffmpeg version" in output.decode("utf-8"):
|
||||
return True
|
||||
else:
|
||||
raise InvokeBadRequestError("ffmpeg is not installed, "
|
||||
"details: https://docs.dify.ai/getting-started/install-self-hosted"
|
||||
"/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech")
|
||||
except Exception:
|
||||
raise InvokeBadRequestError("ffmpeg is not installed, "
|
||||
"details: https://docs.dify.ai/getting-started/install-self-hosted"
|
||||
"/install-faq#id-14.-what-to-do-if-this-error-occurs-in-text-to-speech")
|
||||
|
||||
@@ -224,7 +224,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
||||
|
||||
|
||||
return entity
|
||||
|
||||
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
|
||||
@@ -343,32 +343,44 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
)
|
||||
)
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter='\n\n'):
|
||||
# delimiter for the stream response; the configured value needs unicode_escape decoding
import codecs
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
delimiter = codecs.decode(delimiter, "unicode_escape")
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
if chunk:
|
||||
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
|
||||
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
# stream ended
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"decoded_chunk error,delimiter={delimiter},decoded_chunk={decoded_chunk}")
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered."
|
||||
)
|
||||
break
|
||||
|
||||
if not chunk_json or len(chunk_json['choices']) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk_json['choices'][0]
|
||||
finish_reason = chunk_json['choices'][0].get('finish_reason')
|
||||
chunk_index += 1
|
||||
|
||||
if 'delta' in choice:
|
||||
delta = choice['delta']
|
||||
if delta.get('content') is None or delta.get('content') == '':
|
||||
continue
|
||||
if finish_reason is not None:
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(content=choice.get('text', '')),
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
assistant_message_tool_calls = delta.get('tool_calls', None)
|
||||
# assistant_message_function_call = delta.delta.function_call
|
||||
@@ -387,24 +399,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
|
||||
full_assistant_content += delta.get('content', '')
|
||||
elif 'text' in choice:
|
||||
if choice.get('text') is None or choice.get('text') == '':
|
||||
choice_text = choice.get('text', '')
|
||||
if choice_text == '':
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=choice.get('text', '')
|
||||
)
|
||||
|
||||
full_assistant_content += choice.get('text', '')
|
||||
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
|
||||
full_assistant_content += choice_text
|
||||
else:
|
||||
continue
|
||||
|
||||
# check payload indicator for completion
|
||||
if chunk_json['choices'][0].get('finish_reason') is not None:
|
||||
if finish_reason is not None:
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=chunk_json['choices'][0]['finish_reason']
|
||||
finish_reason=finish_reason
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
|
||||
@@ -75,3 +75,12 @@ model_credential_schema:
|
||||
value: llm
|
||||
default: '4096'
|
||||
type: text-input
|
||||
- variable: stream_mode_delimiter
|
||||
label:
|
||||
zh_Hans: 流模式返回结果的分隔符
|
||||
en_US: Delimiter for streaming results
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
default: '\n\n'
|
||||
type: text-input
|
||||
|
||||
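The `stream_mode_delimiter` value above is saved as literal characters from a text input, so the streaming handler earlier in this diff unescapes it before splitting the response. A self-contained sketch of that step (values are illustrative):

# Illustrative: turn the stored two-character sequence "\n\n" into a real blank-line delimiter,
# via the same codecs.decode(..., "unicode_escape") call used in the handler above.
import codecs

raw_delimiter = '\\n\\n'  # as saved from the text-input credential
delimiter = codecs.decode(raw_delimiter, 'unicode_escape')
assert delimiter == '\n\n'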
api/core/model_runtime/model_providers/tongyi/_common.py (new file, 23 lines)
@@ -0,0 +1,23 @@
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
|
||||
class _CommonTongyi:
|
||||
@staticmethod
|
||||
def _to_credential_kwargs(credentials: dict) -> dict:
|
||||
credentials_kwargs = {
|
||||
"dashscope_api_key": credentials['dashscope_api_key'],
|
||||
}
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
pass
|
||||
@@ -16,6 +16,7 @@ help:
|
||||
en_US: https://dashscope.console.aliyun.com/api-key_management
|
||||
supported_model_types:
|
||||
- llm
|
||||
- tts
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
|
||||
api/core/model_runtime/model_providers/tongyi/tts/tts-1.yaml (new file, 12 lines)
@@ -0,0 +1,12 @@
|
||||
model: tts-1
|
||||
model_type: tts
|
||||
model_properties:
|
||||
default_voice: 'sambert-zhiru-v1'  # for available voices, see https://help.aliyun.com/zh/dashscope/model-list
|
||||
word_limit: 120
|
||||
audio_type: 'mp3'
|
||||
max_workers: 5
|
||||
pricing:
|
||||
input: '1'
|
||||
output: '0'
|
||||
unit: '0.0001'
|
||||
currency: RMB
|
||||
api/core/model_runtime/model_providers/tongyi/tts/tts.py (new file, 142 lines)
@@ -0,0 +1,142 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from functools import reduce
|
||||
from pydub import AudioSegment
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from core.model_runtime.model_providers.tongyi._common import _CommonTongyi
|
||||
|
||||
import dashscope
|
||||
from flask import Response, stream_with_context
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
    """
    Model class for Tongyi text-to-speech model.
    """
|
||||
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_invoke text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param streaming: output is streaming
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
self._is_ffmpeg_installed()
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
if streaming:
|
||||
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
user=user)),
|
||||
status=200, mimetype=f'audio/{audio_type}')
|
||||
else:
|
||||
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, user=user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
validate credentials text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param user: unique user id
|
||||
:return: None; raises CredentialsValidateFailedError if the credentials are invalid
"""
|
||||
try:
|
||||
self._tts_invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
content_text='Hello world!',
|
||||
user=user
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _tts_invoke(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> Response:
|
||||
"""
|
||||
_tts_invoke text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
max_workers = self._get_model_workers_limit(model, credentials)
|
||||
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
audio_bytes_list = list()
|
||||
|
||||
# Create a thread pool and map the function to the list of sentences
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(self._process_sentence, model=model, sentence=sentence,
|
||||
credentials=credentials, audio_type=audio_type) for sentence in sentences]
|
||||
for future in futures:
|
||||
try:
|
||||
audio_bytes_list.append(future.result())
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
|
||||
audio_bytes_list if audio_bytes]
|
||||
combined_segment = reduce(lambda x, y: x + y, audio_segments)
|
||||
buffer: BytesIO = BytesIO()
|
||||
combined_segment.export(buffer, format=audio_type)
|
||||
buffer.seek(0)
|
||||
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
# TODO: improve the streaming function (it currently synthesizes and returns only the first sentence)
|
||||
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> any:
|
||||
"""
|
||||
_tts_invoke_streaming text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
dashscope.api_key = credentials.get('dashscope_api_key')
|
||||
voice_name = self._get_model_voice(model, credentials)
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
audio_type = self._get_model_audio_type(model, credentials)
|
||||
try:
|
||||
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
|
||||
for sentence in sentences:
|
||||
response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice_name, sample_rate=48000, text=sentence.strip(),
|
||||
format=audio_type, word_timestamp_enabled=True,
|
||||
phoneme_timestamp_enabled=True)
|
||||
if isinstance(response.get_audio_data(), bytes):
|
||||
return response.get_audio_data()
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
def _process_sentence(self, sentence: str, model: str, credentials: dict, audio_type: str):
|
||||
"""
|
||||
_tts_invoke Tongyi text2speech model api
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param sentence: text content to be translated
|
||||
:param audio_type: audio file type
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
dashscope.api_key = credentials.get('dashscope_api_key')
|
||||
voice_name = self._get_model_voice(model, credentials)
|
||||
|
||||
response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice_name, sample_rate=48000, text=sentence.strip(), format=audio_type)
|
||||
if isinstance(response.get_audio_data(), bytes):
|
||||
return response.get_audio_data()
|
||||
@@ -1,61 +0,0 @@
|
||||
"""Wrapper around ZhipuAI APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import posixpath
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
from zhipuai.model_api.api import InvokeType
|
||||
from zhipuai.utils import jwt_token
|
||||
from zhipuai.utils.http_client import post, stream
|
||||
from zhipuai.utils.sse_client import SSEClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZhipuModelAPI(BaseModel):
|
||||
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
|
||||
api_key: str
|
||||
api_timeout_seconds = 60
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def invoke(self, **kwargs):
|
||||
url = self._build_api_url(kwargs, InvokeType.SYNC)
|
||||
response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
|
||||
if not response['success']:
|
||||
raise ValueError(
|
||||
f"Error Code: {response['code']}, Message: {response['msg']} "
|
||||
)
|
||||
return response
|
||||
|
||||
def sse_invoke(self, **kwargs):
|
||||
url = self._build_api_url(kwargs, InvokeType.SSE)
|
||||
data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
|
||||
return SSEClient(data)
|
||||
|
||||
def _build_api_url(self, kwargs, *path):
|
||||
if kwargs:
|
||||
if "model" not in kwargs:
|
||||
raise Exception("model param missed")
|
||||
model = kwargs.pop("model")
|
||||
else:
|
||||
model = "-"
|
||||
|
||||
return posixpath.join(self.base_url, model, *path)
|
||||
|
||||
def _generate_token(self):
|
||||
if not self.api_key:
|
||||
raise Exception(
|
||||
"api_key not provided, you could provide it."
|
||||
)
|
||||
|
||||
try:
|
||||
return jwt_token.generate_token(self.api_key)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Your api_key is invalid, please check it."
|
||||
)
|
||||
@@ -3,13 +3,15 @@ from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole,
|
||||
PromptMessageTool, SystemPromptMessage, UserPromptMessage,
|
||||
PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage,
|
||||
TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils import helper
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
|
||||
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
|
||||
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
|
||||
|
||||
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
|
||||
@@ -35,7 +37,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
# invoke model
|
||||
return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, stop, stream, user)
|
||||
return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
@@ -48,7 +50,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages)
|
||||
prompt = self._convert_messages_to_prompt(prompt_messages, tools)
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
@@ -72,6 +74,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
model_parameters={
|
||||
"temperature": 0.5,
|
||||
},
|
||||
tools=[],
|
||||
stream=False
|
||||
)
|
||||
except Exception as ex:
|
||||
@@ -79,6 +82,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
|
||||
def _generate(self, model: str, credentials_kwargs: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[List[str]] = None, stream: bool = True,
|
||||
user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
@@ -97,7 +101,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
if stop:
|
||||
extra_model_kwargs['stop_sequences'] = stop
|
||||
|
||||
client = ZhipuModelAPI(
|
||||
client = ZhipuAI(
|
||||
api_key=credentials_kwargs['api_key']
|
||||
)
|
||||
|
||||
@@ -128,11 +132,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
# not support image message
|
||||
continue
|
||||
|
||||
if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER:
|
||||
if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \
|
||||
copy_prompt_message.role == PromptMessageRole.USER:
|
||||
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
|
||||
else:
|
||||
if copy_prompt_message.role == PromptMessageRole.USER:
|
||||
new_prompt_messages.append(copy_prompt_message)
|
||||
elif copy_prompt_message.role == PromptMessageRole.TOOL:
|
||||
new_prompt_messages.append(copy_prompt_message)
|
||||
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
|
||||
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
|
||||
new_prompt_messages.append(new_prompt_message)
|
||||
else:
|
||||
new_prompt_message = UserPromptMessage(content=copy_prompt_message.content)
|
||||
new_prompt_messages.append(new_prompt_message)
|
||||
@@ -145,7 +155,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
if model == 'glm-4v':
|
||||
params = {
|
||||
'model': model,
|
||||
'prompt': [{
|
||||
'messages': [{
|
||||
'role': prompt_message.role.value,
|
||||
'content':
|
||||
[
|
||||
@@ -171,23 +181,63 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
else:
|
||||
params = {
|
||||
'model': model,
|
||||
'prompt': [{
|
||||
'role': prompt_message.role.value,
|
||||
'content': prompt_message.content,
|
||||
} for prompt_message in new_prompt_messages],
|
||||
'messages': [],
|
||||
**model_parameters
|
||||
}
|
||||
# glm model
|
||||
if not model.startswith('chatglm'):
|
||||
|
||||
for prompt_message in new_prompt_messages:
|
||||
if prompt_message.role == PromptMessageRole.TOOL:
|
||||
params['messages'].append({
|
||||
'role': 'tool',
|
||||
'content': prompt_message.content,
|
||||
'tool_call_id': prompt_message.tool_call_id
|
||||
})
|
||||
else:
|
||||
params['messages'].append({
|
||||
'role': prompt_message.role.value,
|
||||
'content': prompt_message.content
|
||||
})
|
||||
else:
|
||||
# chatglm model
|
||||
for prompt_message in new_prompt_messages:
|
||||
# merge system message to user message
|
||||
if prompt_message.role == PromptMessageRole.SYSTEM or \
|
||||
prompt_message.role == PromptMessageRole.TOOL or \
|
||||
prompt_message.role == PromptMessageRole.USER:
|
||||
if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user':
|
||||
params['messages'][-1]['content'] += "\n\n" + prompt_message.content
|
||||
else:
|
||||
params['messages'].append({
|
||||
'role': 'user',
|
||||
'content': prompt_message.content
|
||||
})
|
||||
else:
|
||||
params['messages'].append({
|
||||
'role': prompt_message.role.value,
|
||||
'content': prompt_message.content
|
||||
})
|
||||
|
||||
if tools and len(tools) > 0:
|
||||
params['tools'] = [
|
||||
{
|
||||
'type': 'function',
|
||||
'function': helper.dump_model(tool)
|
||||
} for tool in tools
|
||||
]
|
||||
|
||||
if stream:
|
||||
response = client.sse_invoke(incremental=True, **params).events()
|
||||
return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages)
|
||||
response = client.chat.completions.create(stream=stream, **params)
|
||||
return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)
|
||||
|
||||
response = client.invoke(**params)
|
||||
return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages)
|
||||
response = client.chat.completions.create(**params)
|
||||
return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(self, model: str,
|
||||
credentials: dict,
|
||||
response: Dict[str, Any],
|
||||
tools: Optional[list[PromptMessageTool]],
|
||||
response: Completion,
|
||||
prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
@@ -197,26 +247,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
data = response["data"]
|
||||
text = ''
|
||||
for res in data["choices"]:
|
||||
text += res['content']
|
||||
assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
|
||||
for choice in response.choices:
|
||||
if choice.message.tool_calls:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if tool_call.type == 'function':
|
||||
assistant_tool_calls.append(
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
text += choice.message.content or ''
|
||||
|
||||
token_usage = data.get("usage")
|
||||
if token_usage is not None:
|
||||
if 'prompt_tokens' not in token_usage:
|
||||
token_usage['prompt_tokens'] = 0
|
||||
if 'completion_tokens' not in token_usage:
|
||||
token_usage['completion_tokens'] = token_usage['total_tokens']
|
||||
prompt_usage = response.usage.prompt_tokens
|
||||
completion_usage = response.usage.completion_tokens
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
|
||||
usage = self._calc_response_usage(model, credentials, prompt_usage, completion_usage)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
message=AssistantPromptMessage(
|
||||
content=text,
|
||||
tool_calls=assistant_tool_calls
|
||||
),
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
@@ -224,7 +287,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
|
||||
def _handle_generate_stream_response(self, model: str,
|
||||
credentials: dict,
|
||||
responses: list[Generator],
|
||||
tools: Optional[list[PromptMessageTool]],
|
||||
responses: Generator[ChatCompletionChunk, None, None],
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
@@ -234,39 +298,64 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator result
|
||||
"""
|
||||
for index, event in enumerate(responses):
|
||||
if event.event == "add":
|
||||
full_assistant_content = ''
|
||||
for chunk in responses:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
|
||||
continue
|
||||
|
||||
assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
|
||||
for tool_call in delta.delta.tool_calls or []:
|
||||
if tool_call.type == 'function':
|
||||
assistant_tool_calls.append(
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta.delta.content if delta.delta.content else '',
|
||||
tool_calls=assistant_tool_calls
|
||||
)
|
||||
|
||||
full_assistant_content += delta.delta.content if delta.delta.content else ''
|
||||
|
||||
if delta.finish_reason is not None and chunk.usage is not None:
|
||||
completion_tokens = chunk.usage.completion_tokens
|
||||
prompt_tokens = chunk.usage.prompt_tokens
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
model=model,
|
||||
system_fingerprint='',
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content=event.data)
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason=delta.finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
elif event.event == "error" or event.event == "interrupted":
|
||||
raise ValueError(
|
||||
f"{event.data}"
|
||||
)
|
||||
elif event.event == "finish":
|
||||
meta = json.loads(event.meta)
|
||||
token_usage = meta['usage']
|
||||
if token_usage is not None:
|
||||
if 'prompt_tokens' not in token_usage:
|
||||
token_usage['prompt_tokens'] = 0
|
||||
if 'completion_tokens' not in token_usage:
|
||||
token_usage['completion_tokens'] = token_usage['total_tokens']
|
||||
|
||||
usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
|
||||
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
model=chunk.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint='',
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content=event.data),
|
||||
finish_reason='finish',
|
||||
usage=usage
|
||||
index=delta.index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -291,11 +380,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return message_text
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Anthropic model
|
||||
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
|
||||
"""
|
||||
:param messages: List of PromptMessage to combine.
|
||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
@@ -306,5 +394,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
for message in messages
|
||||
)
|
||||
|
||||
if tools and len(tools) > 0:
|
||||
text += "\n\nTools:"
|
||||
for tool in tools:
|
||||
text += f"\n{tool.json()}"
|
||||
|
||||
# trim off the trailing ' ' that might come from the "Assistant: "
|
||||
return text.rstrip()
|
||||
return text.rstrip()
|
||||
@@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
|
||||
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
|
||||
from langchain.schema.language_model import _get_token_ids_default_method
|
||||
|
||||
@@ -28,7 +28,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
|
||||
:return: embeddings result
|
||||
"""
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = ZhipuModelAPI(
|
||||
client = ZhipuAI(
|
||||
api_key=credentials_kwargs['api_key']
|
||||
)
|
||||
|
||||
@@ -69,7 +69,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
|
||||
try:
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = ZhipuModelAPI(
|
||||
client = ZhipuAI(
|
||||
api_key=credentials_kwargs['api_key']
|
||||
)
|
||||
|
||||
@@ -82,7 +82,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def embed_documents(self, model: str, client: ZhipuModelAPI, texts: List[str]) -> Tuple[List[List[float]], int]:
|
||||
def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]:
|
||||
"""Call out to ZhipuAI's embedding endpoint.
|
||||
|
||||
Args:
|
||||
@@ -91,17 +91,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
|
||||
|
||||
embeddings = []
|
||||
embedding_used_tokens = 0
|
||||
|
||||
for text in texts:
|
||||
response = client.invoke(model=model, prompt=text)
|
||||
data = response["data"]
|
||||
embeddings.append(data.get('embedding'))
|
||||
response = client.embeddings.create(model=model, input=text)
|
||||
data = response.data[0]
|
||||
embeddings.append(data.embedding)
|
||||
embedding_used_tokens += response.usage.total_tokens
|
||||
|
||||
embedding_used_tokens = data.get('usage')
|
||||
|
||||
return [list(map(float, e)) for e in embeddings], embedding_used_tokens['total_tokens'] if embedding_used_tokens else 0
|
||||
return [list(map(float, e)) for e in embeddings], embedding_used_tokens
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to ZhipuAI's embedding endpoint.
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
from ._client import ZhipuAI
|
||||
|
||||
from .core._errors import (
|
||||
ZhipuAIError,
|
||||
APIStatusError,
|
||||
APIRequestFailedError,
|
||||
APIAuthenticationError,
|
||||
APIReachLimitError,
|
||||
APIInternalError,
|
||||
APIServerFlowExceedError,
|
||||
APIResponseError,
|
||||
APIResponseValidationError,
|
||||
APITimeoutError,
|
||||
)
|
||||
|
||||
from .__version__ import __version__
|
||||
@@ -0,0 +1,2 @@
|
||||
|
||||
__version__ = 'v2.0.1'
|
||||
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, Mapping
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .core import _jwt_token
|
||||
from .core._errors import ZhipuAIError
|
||||
from .core._http_client import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES
|
||||
from .core._base_type import NotGiven, NOT_GIVEN
|
||||
from . import api_resource
|
||||
import os
|
||||
import httpx
|
||||
from httpx import Timeout
|
||||
|
||||
|
||||
class ZhipuAI(HttpClient):
|
||||
chat: api_resource.chat
|
||||
api_key: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
|
||||
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
|
||||
http_client: httpx.Client | None = None,
|
||||
custom_headers: Mapping[str, str] | None = None
|
||||
) -> None:
|
||||
# if api_key is None:
|
||||
# api_key = os.environ.get("ZHIPUAI_API_KEY")
|
||||
if api_key is None:
|
||||
raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供")
|
||||
self.api_key = api_key
|
||||
|
||||
if base_url is None:
|
||||
base_url = os.environ.get("ZHIPUAI_BASE_URL")
|
||||
if base_url is None:
|
||||
base_url = f"https://open.bigmodel.cn/api/paas/v4"
|
||||
from .__version__ import __version__
|
||||
super().__init__(
|
||||
version=__version__,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
custom_httpx_client=http_client,
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
self.chat = api_resource.chat.Chat(self)
|
||||
self.images = api_resource.images.Images(self)
|
||||
self.embeddings = api_resource.embeddings.Embeddings(self)
|
||||
self.files = api_resource.files.Files(self)
|
||||
self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
|
||||
|
||||
@property
|
||||
@override
|
||||
def _auth_headers(self) -> dict[str, str]:
|
||||
api_key = self.api_key
|
||||
return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
|
||||
|
||||
def __del__(self) -> None:
|
||||
if (not hasattr(self, "_has_custom_http_client")
|
||||
or not hasattr(self, "close")
|
||||
or not hasattr(self, "_client")):
|
||||
# if the '__init__' method raised an error, self would not have client attr
|
||||
return
|
||||
|
||||
if self._has_custom_http_client:
|
||||
return
|
||||
|
||||
self.close()
|
||||
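A quick usage sketch of this vendored client, consistent with how llm.py and text_embedding.py above call it; the api_key value and model name are placeholders, not values from the patch.

# Placeholder credentials and model; mirrors client.chat.completions.create(...) as used in llm.py.
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI

client = ZhipuAI(api_key="your-api-key")
response = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "你好"}],
)
print(response.choices[0].message.content)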
@@ -0,0 +1,5 @@
|
||||
from .chat import chat
|
||||
from .images import Images
|
||||
from .embeddings import Embeddings
|
||||
from .files import Files
|
||||
from .fine_tuning import fine_tuning
|
||||
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, List, Optional, TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ...core._base_api import BaseAPI
|
||||
from ...core._base_type import NotGiven, NOT_GIVEN, Headers
|
||||
from ...core._http_client import make_user_request_input
|
||||
from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
|
||||
class AsyncCompletions(BaseAPI):
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
|
||||
temperature: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
top_p: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
max_tokens: int | NotGiven = NOT_GIVEN,
|
||||
seed: int | NotGiven = NOT_GIVEN,
|
||||
messages: Union[str, List[str], List[int], List[List[int]], None],
|
||||
stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
tools: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
tool_choice: str | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> AsyncTaskStatus:
|
||||
_cast_type = AsyncTaskStatus
|
||||
|
||||
if disable_strict_validation:
|
||||
_cast_type = object
|
||||
return self._post(
|
||||
"/async/chat/completions",
|
||||
body={
|
||||
"model": model,
|
||||
"request_id": request_id,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"do_sample": do_sample,
|
||||
"max_tokens": max_tokens,
|
||||
"seed": seed,
|
||||
"messages": messages,
|
||||
"stop": stop,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
},
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers, timeout=timeout
|
||||
),
|
||||
cast_type=_cast_type,
|
||||
enable_stream=False,
|
||||
)
|
||||
|
||||
def retrieve_completion_result(
|
||||
self,
|
||||
id: str,
|
||||
extra_headers: Headers | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Union[AsyncCompletion, AsyncTaskStatus]:
|
||||
_cast_type = Union[AsyncCompletion,AsyncTaskStatus]
|
||||
if disable_strict_validation:
|
||||
_cast_type = object
|
||||
return self._get(
|
||||
path=f"/async-result/{id}",
|
||||
cast_type=_cast_type,
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from .completions import Completions
|
||||
from .async_completions import AsyncCompletions
|
||||
from ...core._base_api import BaseAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
|
||||
class Chat(BaseAPI):
|
||||
completions: Completions
|
||||
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
self.completions = Completions(client)
|
||||
self.asyncCompletions = AsyncCompletions(client)
|
||||
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, List, Optional, TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ...core._base_api import BaseAPI
|
||||
from ...core._base_type import NotGiven, NOT_GIVEN, Headers
|
||||
from ...core._http_client import make_user_request_input
|
||||
from ...core._sse_client import StreamResponse
|
||||
from ...types.chat.chat_completion import Completion
|
||||
from ...types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
|
||||
class Completions(BaseAPI):
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
|
||||
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
|
||||
temperature: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
top_p: Optional[float] | NotGiven = NOT_GIVEN,
|
||||
max_tokens: int | NotGiven = NOT_GIVEN,
|
||||
seed: int | NotGiven = NOT_GIVEN,
|
||||
messages: Union[str, List[str], List[int], object, None],
|
||||
stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
tools: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
tool_choice: str | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> Completion | StreamResponse[ChatCompletionChunk]:
|
||||
_cast_type = Completion
|
||||
_stream_cls = StreamResponse[ChatCompletionChunk]
|
||||
if disable_strict_validation:
|
||||
_cast_type = object
|
||||
_stream_cls = StreamResponse[object]
|
||||
return self._post(
|
||||
"/chat/completions",
|
||||
body={
|
||||
"model": model,
|
||||
"request_id": request_id,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"do_sample": do_sample,
|
||||
"max_tokens": max_tokens,
|
||||
"seed": seed,
|
||||
"messages": messages,
|
||||
"stop": stop,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
"stream": stream,
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
},
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers,
|
||||
),
|
||||
cast_type=_cast_type,
|
||||
enable_stream=stream or False,
|
||||
stream_cls=_stream_cls,
|
||||
)
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, List, Optional, TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core._base_api import BaseAPI
|
||||
from ..core._base_type import NotGiven, NOT_GIVEN, Headers
|
||||
from ..core._http_client import make_user_request_input
|
||||
from ..types.embeddings import EmbeddingsResponded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class Embeddings(BaseAPI):
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
input: Union[str, List[str], List[int], List[List[int]]],
|
||||
model: Union[str],
|
||||
encoding_format: str | NotGiven = NOT_GIVEN,
|
||||
user: str | NotGiven = NOT_GIVEN,
|
||||
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> EmbeddingsResponded:
|
||||
_cast_type = EmbeddingsResponded
|
||||
if disable_strict_validation:
|
||||
_cast_type = object
|
||||
return self._post(
|
||||
"/embeddings",
|
||||
body={
|
||||
"input": input,
|
||||
"model": model,
|
||||
"encoding_format": encoding_format,
|
||||
"user": user,
|
||||
"sensitive_word_check": sensitive_word_check,
|
||||
},
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers, timeout=timeout
|
||||
),
|
||||
cast_type=_cast_type,
|
||||
enable_stream=False,
|
||||
)
|
||||
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core._base_api import BaseAPI
|
||||
from ..core._base_type import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
|
||||
from ..core._files import is_file_content
|
||||
from ..core._http_client import (
|
||||
make_user_request_input,
|
||||
)
|
||||
from ..types.file_object import FileObject, ListOfFileObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
__all__ = ["Files"]
|
||||
|
||||
|
||||
class Files(BaseAPI):
|
||||
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
file: FileTypes,
|
||||
purpose: str,
|
||||
extra_headers: Headers | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FileObject:
|
||||
if not is_file_content(file):
|
||||
prefix = f"Expected file input `{file!r}`"
|
||||
raise RuntimeError(
|
||||
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead."
|
||||
) from None
|
||||
files = [("file", file)]
|
||||
|
||||
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
|
||||
|
||||
return self._post(
|
||||
"/files",
|
||||
body={
|
||||
"purpose": purpose,
|
||||
},
|
||||
files=files,
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers, timeout=timeout
|
||||
),
|
||||
cast_type=FileObject,
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
purpose: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
order: str | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ListOfFileObject:
|
||||
return self._get(
|
||||
"/files",
|
||||
cast_type=ListOfFileObject,
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
query={
|
||||
"purpose": purpose,
|
||||
"limit": limit,
|
||||
"after": after,
|
||||
"order": order,
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from .jobs import Jobs
|
||||
from ...core._base_api import BaseAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
|
||||
class FineTuning(BaseAPI):
|
||||
jobs: Jobs
|
||||
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
self.jobs = Jobs(client)
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core._base_api import BaseAPI
|
||||
from ...core._base_type import NOT_GIVEN, Headers, NotGiven
|
||||
from ...core._http_client import (
|
||||
make_user_request_input,
|
||||
)
|
||||
from ...types.fine_tuning import (
|
||||
FineTuningJob,
|
||||
job_create_params,
|
||||
ListOfFineTuningJob,
|
||||
FineTuningJobEvent,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
__all__ = ["Jobs"]
|
||||
|
||||
|
||||
class Jobs(BaseAPI):
|
||||
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
training_file: str,
|
||||
hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
|
||||
suffix: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
request_id: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
validation_file: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJob:
|
||||
return self._post(
|
||||
"/fine_tuning/jobs",
|
||||
body={
|
||||
"model": model,
|
||||
"training_file": training_file,
|
||||
"hyperparameters": hyperparameters,
|
||||
"suffix": suffix,
|
||||
"validation_file": validation_file,
|
||||
"request_id": request_id,
|
||||
},
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers, timeout=timeout
|
||||
),
|
||||
cast_type=FineTuningJob,
|
||||
)
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJob:
|
||||
return self._get(
|
||||
f"/fine_tuning/jobs/{fine_tuning_job_id}",
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers, timeout=timeout
|
||||
),
|
||||
cast_type=FineTuningJob,
|
||||
)
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ListOfFineTuningJob:
|
||||
return self._get(
|
||||
"/fine_tuning/jobs",
|
||||
cast_type=ListOfFineTuningJob,
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
query={
|
||||
"after": after,
|
||||
"limit": limit,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
def list_events(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
*,
|
||||
after: str | NotGiven = NOT_GIVEN,
|
||||
limit: int | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> FineTuningJobEvent:
|
||||
|
||||
return self._get(
|
||||
f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
|
||||
cast_type=FineTuningJobEvent,
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
query={
|
||||
"after": after,
|
||||
"limit": limit,
|
||||
},
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union, List, Optional, TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core._base_api import BaseAPI
|
||||
from ..core._base_type import NotGiven, NOT_GIVEN, Headers
|
||||
from ..core._http_client import make_user_request_input
|
||||
from ..types.image import ImagesResponded
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class Images(BaseAPI):
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
|
||||
def generations(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str | NotGiven = NOT_GIVEN,
|
||||
n: Optional[int] | NotGiven = NOT_GIVEN,
|
||||
quality: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
response_format: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
size: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
style: Optional[str] | NotGiven = NOT_GIVEN,
|
||||
user: str | NotGiven = NOT_GIVEN,
|
||||
extra_headers: Headers | None = None,
|
||||
disable_strict_validation: Optional[bool] | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ImagesResponded:
|
||||
_cast_type = ImagesResponded
|
||||
if disable_strict_validation:
|
||||
_cast_type = object
|
||||
return self._post(
|
||||
"/images/generations",
|
||||
body={
|
||||
"prompt": prompt,
|
||||
"model": model,
|
||||
"n": n,
|
||||
"quality": quality,
|
||||
"response_format": response_format,
|
||||
"size": size,
|
||||
"style": style,
|
||||
"user": user,
|
||||
},
|
||||
options=make_user_request_input(
|
||||
extra_headers=extra_headers, timeout=timeout
|
||||
),
|
||||
cast_type=_cast_type,
|
||||
enable_stream=False,
|
||||
)
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class BaseAPI:
|
||||
_client: ZhipuAI
|
||||
|
||||
def __init__(self, client: ZhipuAI) -> None:
|
||||
self._client = client
|
||||
self._delete = client.delete
|
||||
self._get = client.get
|
||||
self._post = client.post
|
||||
self._put = client.put
|
||||
self._patch = client.patch
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from os import PathLike
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
Union,
|
||||
Mapping,
|
||||
TypeVar, IO, Tuple, Sequence, Any, List,
|
||||
)
|
||||
|
||||
import pydantic
|
||||
from typing_extensions import (
|
||||
Literal,
|
||||
override,
|
||||
)
|
||||
|
||||
|
||||
Query = Mapping[str, object]
|
||||
Body = object
|
||||
AnyMapping = Mapping[str, object]
|
||||
PrimitiveData = Union[str, int, float, bool, None]
|
||||
Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
|
||||
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
|
||||
_T = TypeVar("_T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
NoneType: Type[None]
|
||||
else:
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
# Sentinel class used until PEP 0661 is accepted
|
||||
class NotGiven(pydantic.BaseModel):
|
||||
"""
|
||||
A sentinel singleton class used to distinguish omitted keyword arguments
|
||||
from those passed in with the value None (which may have different behavior).
|
||||
|
||||
For example:
|
||||
|
||||
```py
|
||||
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
|
||||
|
||||
get(timeout=1) # 1s timeout
|
||||
get(timeout=None) # No timeout
|
||||
get() # Default timeout behavior, which may not be statically known at the method definition.
|
||||
```
|
||||
"""
|
||||
|
||||
def __bool__(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return "NOT_GIVEN"
|
||||
|
||||
|
||||
NotGivenOr = Union[_T, NotGiven]
|
||||
NOT_GIVEN = NotGiven()
|
||||
|
||||
|
||||
class Omit(pydantic.BaseModel):
|
||||
"""In certain situations you need to be able to represent a case where a default value has
|
||||
to be explicitly removed and `None` is not an appropriate substitute, for example:
|
||||
|
||||
```py
|
||||
# as the default `Content-Type` header is `application/json` that will be sent
|
||||
client.post('/upload/files', files={'file': b'my raw file content'})
|
||||
|
||||
# you can't explicitly override the header as it has to be dynamically generated
|
||||
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
|
||||
client.post(..., headers={'Content-Type': 'multipart/form-data'})
|
||||
|
||||
# instead you can remove the default `application/json` header by passing Omit
|
||||
client.post(..., headers={'Content-Type': Omit()})
|
||||
```
|
||||
"""
|
||||
|
||||
def __bool__(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
|
||||
Headers = Mapping[str, Union[str, Omit]]
|
||||
|
||||
ResponseT = TypeVar(
|
||||
"ResponseT",
|
||||
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
|
||||
)
|
||||
|
||||
# for user input files
|
||||
if TYPE_CHECKING:
|
||||
FileContent = Union[IO[bytes], bytes, PathLike[str]]
|
||||
else:
|
||||
FileContent = Union[IO[bytes], bytes, PathLike]
|
||||
|
||||
FileTypes = Union[
|
||||
FileContent, # file content
|
||||
Tuple[str, FileContent], # (filename, file)
|
||||
Tuple[str, FileContent, str], # (filename, file , content_type)
|
||||
Tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers)
|
||||
]
|
||||
|
||||
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
|
||||
|
||||
# for httpx client supported files
|
||||
|
||||
HttpxFileContent = Union[bytes, IO[bytes]]
|
||||
HttpxFileTypes = Union[
|
||||
FileContent, # file content
|
||||
Tuple[str, HttpxFileContent], # (filename, file)
|
||||
Tuple[str, HttpxFileContent, str], # (filename, file , content_type)
|
||||
Tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers)
|
||||
]
|
||||
|
||||
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
|
||||
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
|
||||
__all__ = [
|
||||
"ZhipuAIError",
|
||||
"APIStatusError",
|
||||
"APIRequestFailedError",
|
||||
"APIAuthenticationError",
|
||||
"APIReachLimitError",
|
||||
"APIInternalError",
|
||||
"APIServerFlowExceedError",
|
||||
"APIResponseError",
|
||||
"APIResponseValidationError",
|
||||
"APITimeoutError",
|
||||
]
|
||||
|
||||
|
||||
class ZhipuAIError(Exception):
|
||||
def __init__(self, message: str, ) -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class APIStatusError(Exception):
|
||||
response: httpx.Response
|
||||
status_code: int
|
||||
|
||||
def __init__(self, message: str, *, response: httpx.Response) -> None:
|
||||
super().__init__(message)
|
||||
self.response = response
|
||||
self.status_code = response.status_code
|
||||
|
||||
|
||||
class APIRequestFailedError(APIStatusError):
|
||||
...
|
||||
|
||||
|
||||
class APIAuthenticationError(APIStatusError):
|
||||
...
|
||||
|
||||
|
||||
class APIReachLimitError(APIStatusError):
|
||||
...
|
||||
|
||||
|
||||
class APIInternalError(APIStatusError):
|
||||
...
|
||||
|
||||
|
||||
class APIServerFlowExceedError(APIStatusError):
|
||||
...
|
||||
|
||||
|
||||
class APIResponseError(Exception):
|
||||
message: str
|
||||
request: httpx.Request
|
||||
json_data: object
|
||||
|
||||
def __init__(self, message: str, request: httpx.Request, json_data: object):
|
||||
self.message = message
|
||||
self.request = request
|
||||
self.json_data = json_data
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class APIResponseValidationError(APIResponseError):
|
||||
status_code: int
|
||||
response: httpx.Response
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
json_data: object | None, *,
|
||||
message: str | None = None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
message=message or "Data returned by API invalid for expected schema.",
|
||||
request=response.request,
|
||||
json_data=json_data
|
||||
)
|
||||
self.response = response
|
||||
self.status_code = response.status_code
|
||||
|
||||
|
||||
class APITimeoutError(Exception):
|
||||
request: httpx.Request
|
||||
|
||||
def __init__(self, request: httpx.Request):
|
||||
self.request = request
|
||||
super().__init__("Request Timeout")
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Sequence
|
||||
|
||||
from ._base_type import (
|
||||
FileTypes,
|
||||
HttpxFileTypes,
|
||||
HttpxRequestFiles,
|
||||
RequestFiles,
|
||||
)
|
||||
|
||||
|
||||
def is_file_content(obj: object) -> bool:
|
||||
return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike))
|
||||
|
||||
|
||||
def _transform_file(file: FileTypes) -> HttpxFileTypes:
|
||||
if is_file_content(file):
|
||||
if isinstance(file, os.PathLike):
|
||||
path = Path(file)
|
||||
return path.name, path.read_bytes()
|
||||
else:
|
||||
return file
|
||||
if isinstance(file, tuple):
|
||||
if isinstance(file[1], os.PathLike):
|
||||
return (file[0], Path(file[1]).read_bytes(), *file[2:])
|
||||
else:
|
||||
return (file[0], file[1], *file[2:])
|
||||
else:
|
||||
raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type")
|
||||
|
||||
|
||||
def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
|
||||
if files is None:
|
||||
return None
|
||||
|
||||
if isinstance(files, Mapping):
|
||||
files = {key: _transform_file(file) for key, file in files.items()}
|
||||
elif isinstance(files, Sequence):
|
||||
files = [(key, _transform_file(file)) for key, file in files]
|
||||
else:
|
||||
raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence")
|
||||
return files
|
||||
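A minimal usage sketch of the helpers above (the file name and module path are assumptions, not part of this diff): an `os.PathLike` value becomes the `(filename, bytes)` tuple shape that httpx accepts for multipart uploads.

```python
from pathlib import Path

# Assumption: the hunk above is importable under this (hypothetical) module path.
from zhipuai_sdk.core._files import make_httpx_files

httpx_files = make_httpx_files({"file": Path("report.pdf")})
# -> {"file": ("report.pdf", b"...file bytes...")}, ready to pass as
#    httpx.Client().post(url, files=httpx_files)
```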
@@ -0,0 +1,377 @@
# -*- coding:utf-8 -*-
from __future__ import annotations

import inspect
from typing import (
    Any,
    Type,
    Union,
    cast,
    Mapping,
)

import httpx
import pydantic
from httpx import URL, Timeout

from . import _errors
from ._base_type import NotGiven, ResponseT, Body, Headers, NOT_GIVEN, RequestFiles, Query, Data
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import make_httpx_files
from ._request_opt import ClientRequestParam, UserRequestInput
from ._response import HttpResponse
from ._sse_client import StreamResponse
from ._utils import flatten

headers = {
    "Accept": "application/json",
    "Content-Type": "application/json; charset=UTF-8",
}


def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
    merged = {**map1, **map2}
    return {key: val for key, val in merged.items() if val is not None}


from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT

ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)


class HttpClient:
    _client: httpx.Client
    _version: str
    _base_url: URL

    timeout: Union[float, Timeout, None]
    _limits: httpx.Limits
    _has_custom_http_client: bool
    _default_stream_cls: type[StreamResponse[Any]] | None = None

    def __init__(
        self,
        *,
        version: str,
        base_url: URL,
        timeout: Union[float, Timeout, None],
        custom_httpx_client: httpx.Client | None = None,
        custom_headers: Mapping[str, str] | None = None,
    ) -> None:
        if timeout is None or isinstance(timeout, NotGiven):
            if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT:
                timeout = custom_httpx_client.timeout
            else:
                timeout = ZHIPUAI_DEFAULT_TIMEOUT
        self.timeout = cast(Timeout, timeout)
        self._has_custom_http_client = bool(custom_httpx_client)
        self._client = custom_httpx_client or httpx.Client(
            base_url=base_url,
            timeout=self.timeout,
            limits=ZHIPUAI_DEFAULT_LIMITS,
        )
        self._version = version
        url = URL(url=base_url)
        if not url.raw_path.endswith(b"/"):
            url = url.copy_with(raw_path=url.raw_path + b"/")
        self._base_url = url
        self._custom_headers = custom_headers or {}

    def _prepare_url(self, url: str) -> URL:

        sub_url = URL(url)
        if sub_url.is_relative_url:
            request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/")
            return self._base_url.copy_with(raw_path=request_raw_url)

        return sub_url

    @property
    def _default_headers(self):
        return \
            {
                "Accept": "application/json",
                "Content-Type": "application/json; charset=UTF-8",
                "ZhipuAI-SDK-Ver": self._version,
                "source_type": "zhipu-sdk-python",
                "x-request-sdk": "zhipu-sdk-python",
                **self._auth_headers,
                **self._custom_headers,
            }

    @property
    def _auth_headers(self):
        return {}

    def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers:
        custom_headers = request_param.headers or {}
        headers_dict = _merge_map(self._default_headers, custom_headers)

        httpx_headers = httpx.Headers(headers_dict)

        return httpx_headers

    def _prepare_request(
        self,
        request_param: ClientRequestParam
    ) -> httpx.Request:
        kwargs: dict[str, Any] = {}
        json_data = request_param.json_data
        headers = self._prepare_headers(request_param)
        url = self._prepare_url(request_param.url)
        json_data = request_param.json_data
        if headers.get("Content-Type") == "multipart/form-data":
            headers.pop("Content-Type")

            if json_data:
                kwargs["data"] = self._make_multipartform(json_data)

        return self._client.build_request(
            headers=headers,
            timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout,
            method=request_param.method,
            url=url,
            json=json_data,
            files=request_param.files,
            params=request_param.params,
            **kwargs,
        )

    def _object_to_formfata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
        items = []

        if isinstance(value, Mapping):
            for k, v in value.items():
                items.extend(self._object_to_formfata(f"{key}[{k}]", v))
            return items
        if isinstance(value, (list, tuple)):
            for v in value:
                items.extend(self._object_to_formfata(key + "[]", v))
            return items

        def _primitive_value_to_str(val) -> str:
            # copied from httpx
            if val is True:
                return "true"
            elif val is False:
                return "false"
            elif val is None:
                return ""
            return str(val)

        str_data = _primitive_value_to_str(value)

        if not str_data:
            return []
        return [(key, str_data)]

    def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:

        items = flatten([self._object_to_formfata(k, v) for k, v in data.items()])

        serialized: dict[str, object] = {}
        for key, value in items:
            if key in serialized:
                raise ValueError(f"存在重复的键: {key};")
            serialized[key] = value
        return serialized

    def _parse_response(
        self,
        *,
        cast_type: Type[ResponseT],
        response: httpx.Response,
        enable_stream: bool,
        request_param: ClientRequestParam,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> HttpResponse:

        http_response = HttpResponse(
            raw_response=response,
            cast_type=cast_type,
            client=self,
            enable_stream=enable_stream,
            stream_cls=stream_cls
        )
        return http_response.parse()

    def _process_response_data(
        self,
        *,
        data: object,
        cast_type: type[ResponseT],
        response: httpx.Response,
    ) -> ResponseT:
        if data is None:
            return cast(ResponseT, None)

        try:
            if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel):
                return cast(ResponseT, cast_type.validate(data))

            return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data))
        except pydantic.ValidationError as err:
            raise APIResponseValidationError(response=response, json_data=data) from err

    def is_closed(self) -> bool:
        return self._client.is_closed

    def close(self):
        self._client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def request(
        self,
        *,
        cast_type: Type[ResponseT],
        params: ClientRequestParam,
        enable_stream: bool = False,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> ResponseT | StreamResponse:
        request = self._prepare_request(params)

        try:
            response = self._client.send(
                request,
                stream=enable_stream,
            )
            response.raise_for_status()
        except httpx.TimeoutException as err:
            raise APITimeoutError(request=request) from err
        except httpx.HTTPStatusError as err:
            err.response.read()
            # raise err
            raise self._make_status_error(err.response) from None

        except Exception as err:
            raise err

        return self._parse_response(
            cast_type=cast_type,
            request_param=params,
            response=response,
            enable_stream=enable_stream,
            stream_cls=stream_cls,
        )

    def get(
        self,
        path: str,
        *,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
        enable_stream: bool = False,
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="get", url=path, **options)
        return self.request(
            cast_type=cast_type, params=opts,
            enable_stream=enable_stream
        )

    def post(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
        enable_stream: bool = False,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path,
                                            **options)

        return self.request(
            cast_type=cast_type, params=opts,
            enable_stream=enable_stream,
            stream_cls=stream_cls
        )

    def patch(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
    ) -> ResponseT:
        opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options)

        return self.request(
            cast_type=cast_type, params=opts,
        )

    def put(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files),
                                            **options)

        return self.request(
            cast_type=cast_type, params=opts,
        )

    def delete(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options)

        return self.request(
            cast_type=cast_type, params=opts,
        )

    def _make_status_error(self, response) -> APIStatusError:
        response_text = response.text.strip()
        status_code = response.status_code
        error_msg = f"Error code: {status_code}, with error text {response_text}"

        if status_code == 400:
            return _errors.APIRequestFailedError(message=error_msg, response=response)
        elif status_code == 401:
            return _errors.APIAuthenticationError(message=error_msg, response=response)
        elif status_code == 429:
            return _errors.APIReachLimitError(message=error_msg, response=response)
        elif status_code == 500:
            return _errors.APIInternalError(message=error_msg, response=response)
        elif status_code == 503:
            return _errors.APIServerFlowExceedError(message=error_msg, response=response)
        return APIStatusError(message=error_msg, response=response)


def make_user_request_input(
    max_retries: int | None = None,
    timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
    extra_headers: Headers = None,
    query: Query | None = None,
) -> UserRequestInput:
    options: UserRequestInput = {}

    if extra_headers is not None:
        options["headers"] = extra_headers
    if max_retries is not None:
        options["max_retries"] = max_retries
    if not isinstance(timeout, NotGiven):
        options['timeout'] = timeout
    if query is not None:
        options["params"] = query

    return options
@@ -0,0 +1,30 @@
# -*- coding:utf-8 -*-
import time

import cachetools.func
import jwt

API_TOKEN_TTL_SECONDS = 3 * 60

CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30


@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)
def generate_token(apikey: str):
    try:
        api_key, secret = apikey.split(".")
    except Exception as e:
        raise Exception("invalid api_key", e)

    payload = {
        "api_key": api_key,
        "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
        "timestamp": int(round(time.time() * 1000)),
    }
    ret = jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
    return ret
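A brief usage sketch of the token helper above (the key value is made up, and how the token is attached to outgoing requests is an assumption, not shown in this diff):

```python
# Hypothetical API key in the "<id>.<secret>" form expected by generate_token.
token = generate_token("my-key-id.my-key-secret")

# Repeated calls within the cache TTL (150 s) return the same signed JWT
# without re-encoding; a client would then send it with each request, e.g.
# headers = {"Authorization": token}
```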
@@ -0,0 +1,54 @@
from __future__ import annotations

from typing import Union, Any, cast

import pydantic.generics
from httpx import Timeout
from pydantic import ConfigDict
from typing_extensions import (
    Unpack, ClassVar, TypedDict
)

from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query
from ._utils import remove_notgiven_indict


class UserRequestInput(TypedDict, total=False):
    max_retries: int
    timeout: float | Timeout | None
    headers: Headers
    params: Query | None


class ClientRequestParam():
    method: str
    url: str
    max_retries: Union[int, NotGiven] = NotGiven()
    timeout: Union[float, NotGiven] = NotGiven()
    headers: Union[Headers, NotGiven] = NotGiven()
    json_data: Union[Body, None] = None
    files: Union[HttpxRequestFiles, None] = None
    params: Query = {}
    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

    def get_max_retries(self, max_retries) -> int:
        if isinstance(self.max_retries, NotGiven):
            return max_retries
        return self.max_retries

    @classmethod
    def construct(  # type: ignore
        cls,
        _fields_set: set[str] | None = None,
        **values: Unpack[UserRequestInput],
    ) -> ClientRequestParam :
        kwargs: dict[str, Any] = {
            key: remove_notgiven_indict(value) for key, value in values.items()
        }
        client = cls()
        client.__dict__.update(kwargs)

        return client

    model_construct = construct
@@ -0,0 +1,121 @@
from __future__ import annotations

import datetime
from typing import TypeVar, Generic, cast, Any, TYPE_CHECKING

import httpx
import pydantic
from typing_extensions import ParamSpec, get_origin, get_args

from ._base_type import NoneType
from ._sse_client import StreamResponse

if TYPE_CHECKING:
    from ._http_client import HttpClient

P = ParamSpec("P")
R = TypeVar("R")


class HttpResponse(Generic[R]):
    _cast_type: type[R]
    _client: "HttpClient"
    _parsed: R | None
    _enable_stream: bool
    _stream_cls: type[StreamResponse[Any]]
    http_response: httpx.Response

    def __init__(
        self,
        *,
        raw_response: httpx.Response,
        cast_type: type[R],
        client: "HttpClient",
        enable_stream: bool = False,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> None:
        self._cast_type = cast_type
        self._client = client
        self._parsed = None
        self._stream_cls = stream_cls
        self._enable_stream = enable_stream
        self.http_response = raw_response

    def parse(self) -> R:
        self._parsed = self._parse()
        return self._parsed

    def _parse(self) -> R:
        if self._enable_stream:
            self._parsed = cast(
                R,
                self._stream_cls(
                    cast_type=cast(type, get_args(self._stream_cls)[0]),
                    response=self.http_response,
                    client=self._client
                )
            )
            return self._parsed
        cast_type = self._cast_type
        if cast_type is NoneType:
            return cast(R, None)
        http_response = self.http_response
        if cast_type == str:
            return cast(R, http_response.text)

        content_type, *_ = http_response.headers.get("content-type", "application/json").split(";")
        origin = get_origin(cast_type) or cast_type
        if content_type != "application/json":
            if issubclass(origin, pydantic.BaseModel):
                data = http_response.json()
                return self._client._process_response_data(
                    data=data,
                    cast_type=cast_type,  # type: ignore
                    response=http_response,
                )

            return http_response.text

        data = http_response.json()

        return self._client._process_response_data(
            data=data,
            cast_type=cast_type,  # type: ignore
            response=http_response,
        )

    @property
    def headers(self) -> httpx.Headers:
        return self.http_response.headers

    @property
    def http_request(self) -> httpx.Request:
        return self.http_response.request

    @property
    def status_code(self) -> int:
        return self.http_response.status_code

    @property
    def url(self) -> httpx.URL:
        return self.http_response.url

    @property
    def method(self) -> str:
        return self.http_request.method

    @property
    def content(self) -> bytes:
        return self.http_response.content

    @property
    def text(self) -> str:
        return self.http_response.text

    @property
    def http_version(self) -> str:
        return self.http_response.http_version

    @property
    def elapsed(self) -> datetime.timedelta:
        return self.http_response.elapsed
@@ -0,0 +1,149 @@
# -*- coding:utf-8 -*-
from __future__ import annotations

import json
from typing import Generic, Iterator, TYPE_CHECKING, Mapping

import httpx

from ._base_type import ResponseT
from ._errors import APIResponseError

_FIELD_SEPARATOR = ":"

if TYPE_CHECKING:
    from ._http_client import HttpClient


class StreamResponse(Generic[ResponseT]):

    response: httpx.Response
    _cast_type: type[ResponseT]

    def __init__(
        self,
        *,
        cast_type: type[ResponseT],
        response: httpx.Response,
        client: HttpClient,
    ) -> None:
        self.response = response
        self._cast_type = cast_type
        self._data_process_func = client._process_response_data
        self._stream_chunks = self.__stream__()

    def __next__(self) -> ResponseT:
        return self._stream_chunks.__next__()

    def __iter__(self) -> Iterator[ResponseT]:
        for item in self._stream_chunks:
            yield item

    def __stream__(self) -> Iterator[ResponseT]:

        sse_line_parser = SSELineParser()
        iterator = sse_line_parser.iter_lines(self.response.iter_lines())

        for sse in iterator:
            if sse.data.startswith("[DONE]"):
                break

            if sse.event is None:
                data = sse.json_data()
                if isinstance(data, Mapping) and data.get("error"):
                    raise APIResponseError(
                        message="An error occurred during streaming",
                        request=self.response.request,
                        json_data=data["error"],
                    )

                yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
        for sse in iterator:
            pass


class Event(object):
    def __init__(
        self,
        event: str | None = None,
        data: str | None = None,
        id: str | None = None,
        retry: int | None = None
    ):
        self._event = event
        self._data = data
        self._id = id
        self._retry = retry

    def __repr__(self):
        data_len = len(self._data) if self._data else 0
        return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}"

    @property
    def event(self): return self._event

    @property
    def data(self): return self._data

    def json_data(self): return json.loads(self._data)

    @property
    def id(self): return self._id

    @property
    def retry(self): return self._retry


class SSELineParser:
    _data: list[str]
    _event: str | None
    _retry: int | None
    _id: str | None

    def __init__(self):
        self._event = None
        self._data = []
        self._id = None
        self._retry = None

    def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]:
        for line in lines:
            line = line.rstrip('\n')
            if not line:
                if self._event is None and \
                        not self._data and \
                        self._id is None and \
                        self._retry is None:
                    continue
                sse_event = Event(
                    event=self._event,
                    data='\n'.join(self._data),
                    id=self._id,
                    retry=self._retry
                )
                self._event = None
                self._data = []
                self._id = None
                self._retry = None

                yield sse_event
            self.decode_line(line)

    def decode_line(self, line: str):
        if line.startswith(":") or not line:
            return

        field, _p, value = line.partition(":")

        if value.startswith(' '):
            value = value[1:]
        if field == "data":
            self._data.append(value)
        elif field == "event":
            self._event = value
        elif field == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                pass
        return
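To make the SSE framing above concrete, a small self-contained sketch (the sample lines are made up) that feeds raw server-sent-event lines through the SSELineParser defined above and prints the reassembled events:

```python
raw_lines = iter([
    "event: message",
    'data: {"delta": "Hel"}',
    "",                 # a blank line terminates the first event
    "data: [DONE]",
    "",
])

parser = SSELineParser()
for event in parser.iter_lines(raw_lines):
    print(event.event, event.data)
# -> message {"delta": "Hel"}
# -> None [DONE]
```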
@@ -0,0 +1,18 @@
from __future__ import annotations

from typing import Mapping, Iterable, TypeVar

from ._base_type import NotGiven


def remove_notgiven_indict(obj):
    if obj is None or (not isinstance(obj, Mapping)):
        return obj
    return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}


_T = TypeVar("_T")


def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
    return [item for sublist in t for item in sublist]
@@ -0,0 +1,23 @@
from typing import List, Optional

from pydantic import BaseModel

from .chat_completion import CompletionChoice, CompletionUsage

__all__ = ["AsyncTaskStatus"]


class AsyncTaskStatus(BaseModel):
    id: Optional[str] = None
    request_id: Optional[str] = None
    model: Optional[str] = None
    task_status: Optional[str] = None


class AsyncCompletion(BaseModel):
    id: Optional[str] = None
    request_id: Optional[str] = None
    model: Optional[str] = None
    task_status: str
    choices: List[CompletionChoice]
    usage: CompletionUsage
@@ -0,0 +1,45 @@
from typing import List, Optional

from pydantic import BaseModel

__all__ = ["Completion", "CompletionUsage"]


class Function(BaseModel):
    arguments: str
    name: str


class CompletionMessageToolCall(BaseModel):
    id: str
    function: Function
    type: str


class CompletionMessage(BaseModel):
    content: Optional[str] = None
    role: str
    tool_calls: Optional[List[CompletionMessageToolCall]] = None


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class CompletionChoice(BaseModel):
    index: int
    finish_reason: str
    message: CompletionMessage


class Completion(BaseModel):
    model: Optional[str] = None
    created: Optional[int] = None
    choices: List[CompletionChoice]
    request_id: Optional[str] = None
    id: Optional[str] = None
    usage: CompletionUsage
@@ -0,0 +1,55 @@
from typing import List, Optional

from pydantic import BaseModel

__all__ = [
    "ChatCompletionChunk",
    "Choice",
    "ChoiceDelta",
    "ChoiceDeltaFunctionCall",
    "ChoiceDeltaToolCall",
    "ChoiceDeltaToolCallFunction",
]


class ChoiceDeltaFunctionCall(BaseModel):
    arguments: Optional[str] = None
    name: Optional[str] = None


class ChoiceDeltaToolCallFunction(BaseModel):
    arguments: Optional[str] = None
    name: Optional[str] = None


class ChoiceDeltaToolCall(BaseModel):
    index: int
    id: Optional[str] = None
    function: Optional[ChoiceDeltaToolCallFunction] = None
    type: Optional[str] = None


class ChoiceDelta(BaseModel):
    content: Optional[str] = None
    role: Optional[str] = None
    tool_calls: Optional[List[ChoiceDeltaToolCall]] = None


class Choice(BaseModel):
    delta: ChoiceDelta
    finish_reason: Optional[str] = None
    index: int


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionChunk(BaseModel):
    id: Optional[str] = None
    choices: List[Choice]
    created: Optional[int] = None
    model: Optional[str] = None
    usage: Optional[CompletionUsage] = None
@@ -0,0 +1,8 @@
from typing import Optional

from typing_extensions import TypedDict


class Reference(TypedDict, total=False):
    enable: Optional[bool]
    search_query: Optional[str]
@@ -0,0 +1,20 @@
from __future__ import annotations

from typing import Optional, List

from pydantic import BaseModel
from .chat.chat_completion import CompletionUsage

__all__ = ["Embedding", "EmbeddingsResponded"]


class Embedding(BaseModel):
    object: str
    index: Optional[int] = None
    embedding: List[float]


class EmbeddingsResponded(BaseModel):
    object: str
    data: List[Embedding]
    model: str
    usage: CompletionUsage
@@ -0,0 +1,24 @@
from typing import Optional, List

from pydantic import BaseModel

__all__ = ["FileObject"]


class FileObject(BaseModel):

    id: Optional[str] = None
    bytes: Optional[int] = None
    created_at: Optional[int] = None
    filename: Optional[str] = None
    object: Optional[str] = None
    purpose: Optional[str] = None
    status: Optional[str] = None
    status_details: Optional[str] = None


class ListOfFileObject(BaseModel):

    object: Optional[str] = None
    data: List[FileObject]
    has_more: Optional[bool] = None
@@ -0,0 +1,5 @@
from __future__ import annotations

from .fine_tuning_job import FineTuningJob as FineTuningJob
from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
@@ -0,0 +1,52 @@
from typing import List, Union, Optional
from typing_extensions import Literal

from pydantic import BaseModel

__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ]


class Error(BaseModel):
    code: str
    message: str
    param: Optional[str] = None


class Hyperparameters(BaseModel):
    n_epochs: Union[str, int, None] = None


class FineTuningJob(BaseModel):
    id: Optional[str] = None

    request_id: Optional[str] = None

    created_at: Optional[int] = None

    error: Optional[Error] = None

    fine_tuned_model: Optional[str] = None

    finished_at: Optional[int] = None

    hyperparameters: Optional[Hyperparameters] = None

    model: Optional[str] = None

    object: Optional[str] = None

    result_files: List[str]

    status: str

    trained_tokens: Optional[int] = None

    training_file: str

    validation_file: Optional[str] = None


class ListOfFineTuningJob(BaseModel):
    object: Optional[str] = None
    data: List[FineTuningJob]
    has_more: Optional[bool] = None
@@ -0,0 +1,36 @@
from typing import List, Union, Optional
from typing_extensions import Literal

from pydantic import BaseModel

__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"]


class Metric(BaseModel):
    epoch: Optional[Union[str, int, float]] = None
    current_steps: Optional[int] = None
    total_steps: Optional[int] = None
    elapsed_time: Optional[str] = None
    remaining_time: Optional[str] = None
    trained_tokens: Optional[int] = None
    loss: Optional[Union[str, int, float]] = None
    eval_loss: Optional[Union[str, int, float]] = None
    acc: Optional[Union[str, int, float]] = None
    eval_acc: Optional[Union[str, int, float]] = None
    learning_rate: Optional[Union[str, int, float]] = None


class JobEvent(BaseModel):
    object: Optional[str] = None
    id: Optional[str] = None
    type: Optional[str] = None
    created_at: Optional[int] = None
    level: Optional[str] = None
    message: Optional[str] = None
    data: Optional[Metric] = None


class FineTuningJobEvent(BaseModel):
    object: Optional[str] = None
    data: List[JobEvent]
    has_more: Optional[bool] = None
@@ -0,0 +1,15 @@
from __future__ import annotations

from typing import Union

from typing_extensions import Literal, TypedDict

__all__ = ["Hyperparameters"]


class Hyperparameters(TypedDict, total=False):
    batch_size: Union[Literal["auto"], int]

    learning_rate_multiplier: Union[Literal["auto"], float]

    n_epochs: Union[Literal["auto"], int]
@@ -0,0 +1,18 @@
from __future__ import annotations

from typing import Optional, List

from pydantic import BaseModel

__all__ = ["GeneratedImage", "ImagesResponded"]


class GeneratedImage(BaseModel):
    b64_json: Optional[str] = None
    url: Optional[str] = None
    revised_prompt: Optional[str] = None


class ImagesResponded(BaseModel):
    created: int
    data: List[GeneratedImage]
@@ -31,6 +31,7 @@ import mimetypes
 logger = logging.getLogger(__name__)
 
 _builtin_providers = {}
+_builtin_tools_labels = {}
 
 class ToolManager:
     @staticmethod
@@ -233,7 +234,7 @@ class ToolManager:
         if len(_builtin_providers) > 0:
             return list(_builtin_providers.values())
 
-        builtin_providers = []
+        builtin_providers: List[BuiltinToolProviderController] = []
         for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
             if provider.startswith('__'):
                 continue
@@ -264,8 +265,30 @@ class ToolManager:
        # cache the builtin providers
        for provider in builtin_providers:
            _builtin_providers[provider.identity.name] = provider
            for tool in provider.get_tools():
                _builtin_tools_labels[tool.identity.name] = tool.identity.label

        return builtin_providers

    @staticmethod
    def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
        """
            get the tool label

            :param tool_name: the name of the tool

            :return: the label of the tool
        """
        global _builtin_tools_labels
        if len(_builtin_tools_labels) == 0:
            # init the builtin providers
            ToolManager.list_builtin_providers()

        if tool_name not in _builtin_tools_labels:
            return None

        return _builtin_tools_labels[tool_name]

    @staticmethod
    def user_list_providers(
        user_id: str,
@@ -49,10 +49,11 @@ agent_thought_fields = {
     'position': fields.Integer,
     'thought': fields.String,
     'tool': fields.String,
+    'tool_labels': fields.Raw,
     'tool_input': fields.String,
     'created_at': TimestampField,
     'observation': fields.String,
-    'files': fields.List(fields.String)
+    'files': fields.List(fields.String),
 }
 
 message_detail_fields = {
@@ -36,6 +36,7 @@ agent_thought_fields = {
     'position': fields.Integer,
     'thought': fields.String,
     'tool': fields.String,
+    'tool_labels': fields.Raw,
     'tool_input': fields.String,
     'created_at': TimestampField,
     'observation': fields.String,
api/libs/gmpy2_pkcs10aep_cipher.py — new file, 277 lines
@@ -0,0 +1,277 @@
# -*- coding: utf-8 -*-
#
# Cipher/PKCS1_OAEP.py : PKCS#1 OAEP
#
# ===================================================================
# The contents of this file are dedicated to the public domain. To
# the extent that dedication to the public domain is not available,
# everyone is granted a worldwide, perpetual, royalty-free,
# non-exclusive license to exercise all rights associated with the
# contents of this file for any purpose whatsoever.
# No rights are reserved.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ===================================================================

from Crypto.Signature.pss import MGF1
import Crypto.Hash.SHA1

from Crypto.Util.py3compat import bord, _copy_bytes
import Crypto.Util.number
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
from Crypto.Util.strxor import strxor
from Crypto import Random
from hashlib import sha1
import gmpy2

class PKCS1OAEP_Cipher:
    """Cipher object for PKCS#1 v1.5 OAEP.
    Do not create directly: use :func:`new` instead."""

    def __init__(self, key, hashAlgo, mgfunc, label, randfunc):
        """Initialize this PKCS#1 OAEP cipher object.

        :Parameters:
         key : an RSA key object
                If a private half is given, both encryption and decryption are possible.
                If a public half is given, only encryption is possible.
         hashAlgo : hash object
                The hash function to use. This can be a module under `Crypto.Hash`
                or an existing hash object created from any of such modules. If not specified,
                `Crypto.Hash.SHA1` is used.
         mgfunc : callable
                A mask generation function that accepts two parameters: a string to
                use as seed, and the lenth of the mask to generate, in bytes.
                If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
         label : bytes/bytearray/memoryview
                A label to apply to this particular encryption. If not specified,
                an empty string is used. Specifying a label does not improve
                security.
         randfunc : callable
                A function that returns random bytes.

        :attention: Modify the mask generation function only if you know what you are doing.
                    Sender and receiver must use the same one.
        """
        self._key = key

        if hashAlgo:
            self._hashObj = hashAlgo
        else:
            self._hashObj = Crypto.Hash.SHA1

        if mgfunc:
            self._mgf = mgfunc
        else:
            self._mgf = lambda x,y: MGF1(x,y,self._hashObj)

        self._label = _copy_bytes(None, None, label)
        self._randfunc = randfunc

    def can_encrypt(self):
        """Legacy function to check if you can call :meth:`encrypt`.

        .. deprecated:: 3.0"""
        return self._key.can_encrypt()

    def can_decrypt(self):
        """Legacy function to check if you can call :meth:`decrypt`.

        .. deprecated:: 3.0"""
        return self._key.can_decrypt()

    def encrypt(self, message):
        """Encrypt a message with PKCS#1 OAEP.

        :param message:
            The message to encrypt, also known as plaintext. It can be of
            variable length, but not longer than the RSA modulus (in bytes)
            minus 2, minus twice the hash output size.
            For instance, if you use RSA 2048 and SHA-256, the longest message
            you can encrypt is 190 byte long.
        :type message: bytes/bytearray/memoryview

        :returns: The ciphertext, as large as the RSA modulus.
        :rtype: bytes

        :raises ValueError:
            if the message is too long.
        """

        # See 7.1.1 in RFC3447
        modBits = Crypto.Util.number.size(self._key.n)
        k = ceil_div(modBits, 8)  # Convert from bits to bytes
        hLen = self._hashObj.digest_size
        mLen = len(message)

        # Step 1b
        ps_len = k - mLen - 2 * hLen - 2
        if ps_len < 0:
            raise ValueError("Plaintext is too long.")
        # Step 2a
        lHash = sha1(self._label).digest()
        # Step 2b
        ps = b'\x00' * ps_len
        # Step 2c
        db = lHash + ps + b'\x01' + _copy_bytes(None, None, message)
        # Step 2d
        ros = self._randfunc(hLen)
        # Step 2e
        dbMask = self._mgf(ros, k-hLen-1)
        # Step 2f
        maskedDB = strxor(db, dbMask)
        # Step 2g
        seedMask = self._mgf(maskedDB, hLen)
        # Step 2h
        maskedSeed = strxor(ros, seedMask)
        # Step 2i
        em = b'\x00' + maskedSeed + maskedDB
        # Step 3a (OS2IP)
        em_int = bytes_to_long(em)
        # Step 3b (RSAEP)
        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
        # Step 3c (I2OSP)
        c = long_to_bytes(m_int, k)
        return c

    def decrypt(self, ciphertext):
        """Decrypt a message with PKCS#1 OAEP.

        :param ciphertext: The encrypted message.
        :type ciphertext: bytes/bytearray/memoryview

        :returns: The original message (plaintext).
        :rtype: bytes

        :raises ValueError:
            if the ciphertext has the wrong length, or if decryption
            fails the integrity check (in which case, the decryption
            key is probably wrong).
        :raises TypeError:
            if the RSA key has no private half (i.e. you are trying
            to decrypt using a public key).
        """
        # See 7.1.2 in RFC3447
        modBits = Crypto.Util.number.size(self._key.n)
        k = ceil_div(modBits,8)  # Convert from bits to bytes
        hLen = self._hashObj.digest_size
        # Step 1b and 1c
        if len(ciphertext) != k or k<hLen+2:
            raise ValueError("Ciphertext with incorrect length.")
        # Step 2a (O2SIP)
        ct_int = bytes_to_long(ciphertext)
        # Step 2b (RSADP)
        # m_int = self._key._decrypt(ct_int)
        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
        # Complete step 2c (I2OSP)
        em = long_to_bytes(m_int, k)
        # Step 3a
        lHash = sha1(self._label).digest()
        # Step 3b
        y = em[0]
        # y must be 0, but we MUST NOT check it here in order not to
        # allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143)
        maskedSeed = em[1:hLen+1]
        maskedDB = em[hLen+1:]
        # Step 3c
        seedMask = self._mgf(maskedDB, hLen)
        # Step 3d
        seed = strxor(maskedSeed, seedMask)
        # Step 3e
        dbMask = self._mgf(seed, k-hLen-1)
        # Step 3f
        db = strxor(maskedDB, dbMask)
        # Step 3g
        one_pos = hLen + db[hLen:].find(b'\x01')
        lHash1 = db[:hLen]
        invalid = bord(y) | int(one_pos < hLen)
        hash_compare = strxor(lHash1, lHash)
        for x in hash_compare:
            invalid |= bord(x)
        for x in db[hLen:one_pos]:
            invalid |= bord(x)
        if invalid != 0:
            raise ValueError("Incorrect decryption.")
        # Step 4
        return db[one_pos + 1:]

def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
    """Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.

    :param key:
      The key object to use to encrypt or decrypt the message.
      Decryption is only possible with a private RSA key.
    :type key: RSA key object

    :param hashAlgo:
      The hash function to use. This can be a module under `Crypto.Hash`
      or an existing hash object created from any of such modules.
      If not specified, `Crypto.Hash.SHA1` is used.
    :type hashAlgo: hash object

    :param mgfunc:
      A mask generation function that accepts two parameters: a string to
      use as seed, and the lenth of the mask to generate, in bytes.
      If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
    :type mgfunc: callable

    :param label:
      A label to apply to this particular encryption. If not specified,
      an empty string is used. Specifying a label does not improve
      security.
    :type label: bytes/bytearray/memoryview

    :param randfunc:
      A function that returns random bytes.
      The default is `Random.get_random_bytes`.
    :type randfunc: callable
    """

    if randfunc is None:
        randfunc = Random.get_random_bytes
    return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc)


def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
    """Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.

    :param key:
      The key object to use to encrypt or decrypt the message.
      Decryption is only possible with a private RSA key.
    :type key: RSA key object

    :param hashAlgo:
      The hash function to use. This can be a module under `Crypto.Hash`
      or an existing hash object created from any of such modules.
      If not specified, `Crypto.Hash.SHA1` is used.
    :type hashAlgo: hash object

    :param mgfunc:
      A mask generation function that accepts two parameters: a string to
      use as seed, and the lenth of the mask to generate, in bytes.
      If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
    :type mgfunc: callable

    :param label:
      A label to apply to this particular encryption. If not specified,
      an empty string is used. Specifying a label does not improve
      security.
    :type label: bytes/bytearray/memoryview

    :param randfunc:
      A function that returns random bytes.
      The default is `Random.get_random_bytes`.
    :type randfunc: callable
    """

    if randfunc is None:
        randfunc = Random.get_random_bytes
    return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc)
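A hedged round-trip sketch (not part of the diff) of how the `new()` helper above would typically be driven with a PyCryptodome RSA key; it assumes pycryptodome and gmpy2 are installed and that the file is importable as `gmpy2_pkcs10aep_cipher`.

```python
from Crypto.PublicKey import RSA

import gmpy2_pkcs10aep_cipher as oaep  # assumption: module is on the import path

key = RSA.generate(2048)        # private key, so both encrypt and decrypt work
cipher = oaep.new(key)          # defaults: SHA-1 hash, MGF1, empty label

ciphertext = cipher.encrypt(b"hello dify")
assert cipher.decrypt(ciphertext) == b"hello dify"
```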
Some files were not shown because too many files have changed in this diff.