Compare commits

...

54 Commits
0.6.4 ... 0.6.5

Author SHA1 Message Date
takatost
34bfb715e1 fix: citations always appear in the chatflow app (#3844) 2024-04-25 18:31:38 +08:00
Joel
019d7069f8 fix: debug run not show total right tokens (#3843) 2024-04-25 18:22:30 +08:00
Bowen Liang
c54fcfb45d extract enum type for tenant account role (#3788) 2024-04-25 18:20:08 +08:00
zxhlyh
cde87cb225 fix: model parameter default value (#3841) 2024-04-25 18:04:37 +08:00
takatost
12435774ca feat: query prompt template support in chatflow (#3791)
Co-authored-by: Joel <iamjoel007@gmail.com>
2024-04-25 18:01:53 +08:00
Henrybit
80b9507e7a feat: add aliyun oss storage (#3690)
Co-authored-by: henrybit <qipenghui3056@sina.com>
2024-04-25 16:57:19 +08:00
takatost
0ac0f0ffd0 version to 0.6.5 (#3834) 2024-04-25 16:50:37 +08:00
KVOJJJin
3d14aba4b4 Fix: event of click away in message-log-modal (#3828) 2024-04-25 15:58:03 +08:00
Leon cap
64f694865c Update EN,KL,JA,FR,ES documentation Llama2 to Llama3 model support (#3827) 2024-04-25 15:52:00 +08:00
zxhlyh
d36b728088 fix: workflow sync data (#3824) 2024-04-25 14:02:06 +08:00
KVOJJJin
1a7b4c42ab fix: event of keyboard "enter" in text generator app (#3823) 2024-04-25 13:58:06 +08:00
Joel
2a64ce740e chore: remove anthropic pay entrance (#3822) 2024-04-25 13:18:59 +08:00
呆萌闷油瓶
78988ed60e fix:still enable SSL verification when using qdrant based on HTTP protocol (#3805) 2024-04-25 13:04:31 +08:00
Yeuoly
2832adda88 fix: missing url field when searching special keywords (#3820) 2024-04-25 12:33:58 +08:00
takatost
a4e4fb4094 fix: credentials validate failed for groqcloud model provider (#3817) 2024-04-25 12:09:44 +08:00
YidaHu
777ec64635 feat: add log_file environment variable (#3793) 2024-04-24 21:55:14 +08:00
Bowen Liang
9cec8c1750 test: add unit tests for vector stores of Milvus, Qdrant and Weaviate (#3688) 2024-04-24 21:52:42 +08:00
Bowen Liang
8ca5aa1190 use pymilvus 2.3.7 (#3790) 2024-04-24 18:37:08 +08:00
takatost
4d8f1b9ca4 feat: test all unit tests (#3787)
Co-authored-by: Joel <iamjoel007@gmail.com>
2024-04-24 17:33:01 +08:00
takatost
3da179f77b feat: add conversation_id and user_id in chatflow/workflow system vars (#3771)
Co-authored-by: Joel <iamjoel007@gmail.com>
2024-04-24 17:20:01 +08:00
Bowen Liang
a34e8cb0bd test: add test for PKCS1OAEP_Cipher with gmpy2 (#3760) 2024-04-24 17:15:31 +08:00
KVOJJJin
b249767c5c Fix: redirection of app remove (#3770) 2024-04-24 17:11:51 +08:00
Joel
89a7434565 fix: handle inputs show the focus ui together in tools node (#3763) 2024-04-24 15:53:07 +08:00
ugyuji
3b537cbdeb fix: endpoint for 'Update a document from a file' (#3751) 2024-04-24 15:25:53 +08:00
zxhlyh
731464f5b8 fix: workflow sync (#3756) 2024-04-24 15:19:19 +08:00
Joel
1ad70f8721 feat: support prompt messages sorting (#3757) 2024-04-24 15:09:01 +08:00
takatost
2ea8c73cd8 fix: type num of variable converted to str (#3758) 2024-04-24 15:07:56 +08:00
Jyong
f257f2c396 Knowledge optimization (#3755)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
2024-04-24 15:02:29 +08:00
Joel
3cd8e6f5c6 fix: llm editor readonly cover error (#3752) 2024-04-24 13:28:22 +08:00
KVOJJJin
0715db7681 chore: add selector for use app store (#3746) 2024-04-24 13:07:20 +08:00
zxhlyh
a39de8a686 fix: workflow restore (#3750) 2024-04-24 13:05:33 +08:00
Bowen Liang
ccaf335466 fix: rollback gmpy2 to 2.1.5 (#3745) 2024-04-24 12:53:23 +08:00
legao
40e36e9b52 fix: toggling AppDetailNav causes unnecessary component rerenders (#3718) 2024-04-24 12:07:28 +08:00
zxhlyh
9eebe9d54e fix: workflow node variable (#3743) 2024-04-24 11:41:12 +08:00
crazywoola
a23a191615 feat: add copy button to code (#3719) 2024-04-24 09:34:51 +08:00
Leo Q
7d9c5586f9 Update "@formatjs/intl-localematcher" to version 0.5.4 in package.json (#3726) 2024-04-24 09:06:23 +08:00
Ikko Eltociear Ashimine
f07c89bba4 Update README_JA.md (#3727) 2024-04-24 09:04:27 +08:00
1102
59cba930e5 bedrock llm Model file name change (#3714)
Co-authored-by: heshunchang <shuncanghe@clouditera.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-04-23 18:57:34 +08:00
zxhlyh
39ae56e136 fix: workflow connection (#3713) 2024-04-23 18:02:15 +08:00
Joel
f92130338b feat: prompt editor support auto height by content height and fix some bugs (#3712) 2024-04-23 17:46:59 +08:00
Bowen Liang
2867d29021 fix: milvus usage with create_collection (#3683) 2024-04-23 17:37:40 +08:00
呆萌闷油瓶
f76ac8bdee enhance:speedup xinference audio transcription (#3636) 2024-04-23 17:09:30 +08:00
zxhlyh
83caffe000 fix: workflow restore (#3711) 2024-04-23 17:02:23 +08:00
Luvian77
96160837d2 fix: cannot change file uploader method (#3710) 2024-04-23 17:02:12 +08:00
Yeuoly
3480f1c59e refactor: tool parameter cache (#3703) 2024-04-23 15:22:42 +08:00
zxhlyh
65ac4f69af fix: workflow shortcuts (#3701) 2024-04-23 14:45:57 +08:00
Yeuoly
2c50fab3dd fix: skip dataset icon (#3696) 2024-04-23 12:41:41 +08:00
Carson Kahn
9525ccac4f Localize links to localized READMEs (#3689) 2024-04-23 09:30:32 +08:00
Richards Tu
ff76c4bd5d Add new tool: Judge0 CE (#3684)
Co-authored-by: crazywoola <427733928@qq.com>
2024-04-23 09:07:21 +08:00
Luvian77
5dacf77627 fix: Added prevention of click event propagation for overlay layer (#3666)
Co-authored-by: crazywoola <427733928@qq.com>
2024-04-22 19:53:20 +08:00
Yeuoly
2a213c6af7 fix: incorrect type parser (#3682) 2024-04-22 19:32:41 +08:00
Bowen Liang
b2535e7db6 chore: update description of code interpreter tool (#3679)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-04-22 19:19:16 +08:00
longzhihun
28236147ee feat: add support for bedrock Mistral AI model (#3676)
Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
2024-04-22 17:24:02 +08:00
Chenhe Gu
4969783383 add groq llama3 (#3673) 2024-04-22 15:21:09 +08:00
225 changed files with 5637 additions and 1586 deletions

View File

@@ -10,7 +10,9 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version:
- "3.10"
- "3.11"
env:
OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
@@ -35,10 +37,26 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Install APT packages
uses: awalsh128/cache-apt-pkgs-action@v1
- name: Set up Weaviate
uses: hoverkraft-tech/compose-action@v2.0.0
with:
packages: ffmpeg
compose-file: docker/docker-compose.middleware.yaml
services: weaviate
- name: Set up Qdrant
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: docker/docker-compose.qdrant.yaml
services: qdrant
- name: Set up Milvus
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: docker/docker-compose.milvus.yaml
services: |
etcd
minio
milvus-standalone
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
@@ -52,6 +70,9 @@ jobs:
- name: Install dependencies
run: pip install -r ./api/requirements.txt -r ./api/requirements-dev.txt
- name: Run Unit tests
run: dev/pytest/pytest_unit_tests.sh
- name: Run ModelRuntime
run: dev/pytest/pytest_model_runtime.sh
@@ -60,3 +81,6 @@ jobs:
- name: Run Workflow
run: dev/pytest/pytest_workflow.sh
- name: Run Vector Stores
run: dev/pytest/pytest_vdb.sh

View File

@@ -29,12 +29,12 @@
</p>
<p align="center">
<a href="./README.md"><img alt="Commits last month" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README_CN.md"><img alt="Commits last month" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README_JA.md"><img alt="Commits last month" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README_ES.md"><img alt="Commits last month" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README_FR.md"><img alt="Commits last month" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README_KL.md"><img alt="Commits last month" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
</p>
#
@@ -54,7 +54,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
**2. Comprehensive model support**:
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)

View File

@@ -54,7 +54,7 @@ Dify es una plataforma de desarrollo de aplicaciones de LLM de código abierto.
**2. Soporte de modelos completo**:
Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama2 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers).
Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama3 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers).
![proveedores-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)

View File

@@ -54,7 +54,7 @@ Dify est une plateforme de développement d'applications LLM open source. Son in
**2. Prise en charge complète des modèles**:
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama2, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)

View File

@@ -55,9 +55,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
**2. 網羅的なモデルサポート**:
数百のプロプライエタリ/オープンソースのLLMと、数十の推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama2、およびOpenAI API互換のモデルをカバーします。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs
.dify.ai/getting-started/readme/model-providers)をご覧ください。
数百のプロプライエタリ/オープンソースのLLMと、数十の推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama3、およびOpenAI API互換のモデルをカバーします。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs.dify.ai/getting-started/readme/model-providers)をご覧ください。
![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)
@@ -155,9 +153,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
さらなる参照や詳細な手順については、[ドキュメント](https://docs.dify.ai)をご覧ください。
- **エンタープライズ/組織向けのDify</br>**
追加のエンタープライズ向け機能を提供しています。[こちらからミーティ
ングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)してエンタープライズのニーズについて相談してください。 </br>
追加のエンタープライズ向け機能を提供しています。[こちらからミーティングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)してエンタープライズのニーズについて相談してください。 </br>
> AWSを使用しているスタートアップや中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで独自のAWS VPCにデプロイできます。カスタムロゴとブランディングでアプリを作成するオプションを備えた手頃な価格のAMIオファリングです。

View File

@@ -54,7 +54,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
**2. Comprehensive model support**:
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)

View File

@@ -52,6 +52,11 @@ AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
AZURE_BLOB_CONTAINER_NAME=yout-container-name
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
# Aliyun oss Storage configuration
ALIYUN_OSS_BUCKET_NAME=your-bucket-name
ALIYUN_OSS_ACCESS_KEY=your-access-key
ALIYUN_OSS_SECRET_KEY=your-secret-key
ALIYUN_OSS_ENDPOINT=your-endpoint
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@@ -160,3 +165,6 @@ CODE_MAX_NUMBER_ARRAY_LENGTH=1000
# API Tool configuration
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
API_TOOL_DEFAULT_READ_TIMEOUT=60
# Log file path
LOG_FILE=

View File

@@ -104,7 +104,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.6.4"
self.CURRENT_VERSION = "0.6.5"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -208,6 +208,10 @@ class Config:
self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')
self.ALIYUN_OSS_BUCKET_NAME=get_env('ALIYUN_OSS_BUCKET_NAME')
self.ALIYUN_OSS_ACCESS_KEY=get_env('ALIYUN_OSS_ACCESS_KEY')
self.ALIYUN_OSS_SECRET_KEY=get_env('ALIYUN_OSS_SECRET_KEY')
self.ALIYUN_OSS_ENDPOINT=get_env('ALIYUN_OSS_ENDPOINT')
# ------------------------
# Vector Store Configurations.
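
The storage backend that consumes these four settings is not part of this hunk. Assuming it uses the official `oss2` SDK, they wire up roughly as below — the endpoint and bucket values here are illustrative, not taken from the diff:

```python
import oss2  # Aliyun OSS SDK; assumed client for these settings

# Illustrative stand-ins for the four new environment variables.
auth = oss2.Auth('your-access-key', 'your-secret-key')        # ALIYUN_OSS_ACCESS_KEY / _SECRET_KEY
bucket = oss2.Bucket(auth,
                     'https://oss-cn-hangzhou.aliyuncs.com',  # ALIYUN_OSS_ENDPOINT (illustrative)
                     'your-bucket-name')                      # ALIYUN_OSS_BUCKET_NAME

bucket.put_object('datasets/example.txt', b'hello')
print(bucket.get_object('datasets/example.txt').read())
```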

View File

@@ -53,5 +53,8 @@ from .explore import (
workflow,
)
# Import tag controllers
from .tag import tags
# Import workspace controllers
from .workspace import account, members, model_providers, models, tool_providers, workspace

View File

@@ -1,17 +1,16 @@
import json
import uuid
from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from werkzeug.exceptions import BadRequest, Forbidden
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from extensions.ext_database import db
from fields.app_fields import (
app_detail_fields,
app_detail_fields_with_site,
@@ -20,6 +19,7 @@ from fields.app_fields import (
from libs.login import login_required
from models.model import App, AppMode, AppModelConfig
from services.app_service import AppService
from services.tag_service import TagService
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
@@ -29,21 +29,29 @@ class AppListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_pagination_fields)
def get(self):
"""Get app list"""
def uuid_list(value):
try:
return [str(uuid.UUID(v)) for v in value.split(',')]
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser()
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
parser.add_argument('name', type=str, location='args', required=False)
parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
args = parser.parse_args()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
if not app_pagination:
return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
return app_pagination
return marshal(app_pagination, app_pagination_fields)
@setup_required
@login_required
@@ -108,43 +116,9 @@ class AppApi(Resource):
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
# get original app model config
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
model_config: AppModelConfig = app_model.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
app_service = AppService()
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
except Exception as e:
pass
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
db.session.commit()
app_model = app_service.get_app(app_model)
return app_model
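
The `uuid_list` coercer added to the list endpoint is small enough to exercise standalone; a sketch of the same logic outside reqparse:

```python
import uuid

def uuid_list(value: str) -> list[str]:
    # Split the comma-separated query value and normalize each piece through
    # uuid.UUID; a malformed ID raises ValueError (mapped to a 400 above).
    return [str(uuid.UUID(v)) for v in value.split(',')]

print(uuid_list('123e4567-e89b-12d3-a456-426614174000'))
# ['123e4567-e89b-12d3-a456-426614174000']
```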

View File

@@ -57,6 +57,7 @@ class ModelConfigResource(Resource):
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
@@ -64,6 +65,7 @@ class ModelConfigResource(Resource):
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app_model.id}'
)
except Exception as e:
continue
@@ -94,6 +96,7 @@ class ModelConfigResource(Resource):
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
except Exception as e:
@@ -104,6 +107,7 @@ class ModelConfigResource(Resource):
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app_model.id}'
)
manager.delete_tool_parameters_cache()
@@ -111,9 +115,11 @@ class ModelConfigResource(Resource):
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]
for masked_key, masked_value in masked_parameter_map[key].items():
if masked_key in agent_tool_entity.tool_parameters and \
agent_tool_entity.tool_parameters[masked_key] == masked_value:
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
# encrypt parameters
if agent_tool_entity.tool_parameters:

View File

@@ -48,11 +48,14 @@ class DatasetListApi(Resource):
limit = request.args.get('limit', default=20, type=int)
ids = request.args.getlist('ids')
provider = request.args.get('provider', default="vendor")
search = request.args.get('keyword', default=None, type=str)
tag_ids = request.args.getlist('tag_ids')
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(page, limit, provider,
current_user.current_tenant_id, current_user)
current_user.current_tenant_id, current_user, search, tag_ids)
# check embedding setting
provider_manager = ProviderManager()
@@ -184,6 +187,10 @@ class DatasetApi(Resource):
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
parser.add_argument('embedding_model', type=str,
location='json', help='Invalid embedding model.')
parser.add_argument('embedding_model_provider', type=str,
location='json', help='Invalid embedding model provider.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
args = parser.parse_args()
@@ -506,10 +513,27 @@ class DatasetRetrievalSettingMockApi(Resource):
else:
raise ValueError("Unsupported vector db type.")
class DatasetErrorDocs(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
return {
'data': [marshal(item, document_status_fields) for item in results],
'total': len(results)
}, 200
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')

View File

@@ -1,3 +1,4 @@
import logging
from datetime import datetime, timezone
from flask import request
@@ -233,7 +234,7 @@ class DatasetDocumentListApi(Resource):
location='json')
parser.add_argument('data_source', type=dict, required=False, location='json')
parser.add_argument('process_rule', type=dict, required=False, location='json')
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
@@ -883,6 +884,49 @@ class DocumentRecoverApi(DocumentResource):
return {'result': 'success'}, 204
class DocumentRetryApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id):
"""retry document."""
parser = reqparse.RequestParser()
parser.add_argument('document_ids', type=list, required=True, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
retry_documents = []
if not dataset:
raise NotFound('Dataset not found.')
for document_id in args['document_ids']:
try:
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
raise NotFound("Document Not Exists.")
# 403 if document is archived
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
# 400 if document is completed
if document.indexing_status == 'completed':
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception as e:
logging.error(f"Document {document_id} retry failed: {str(e)}")
continue
# retry document
DocumentService.retry_document(dataset_id, retry_documents)
return {'result': 'success'}, 204
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@@ -908,3 +952,4 @@ api.add_resource(DocumentStatusApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
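
A hypothetical client call against the new retry route — the host, `/console/api` prefix, auth scheme, and placeholder IDs are assumptions, not shown in this diff:

```python
import requests

resp = requests.post(
    'http://localhost:5001/console/api/datasets/<dataset-uuid>/retry',  # hypothetical host/prefix
    headers={'Authorization': 'Bearer <console-token>'},                # auth scheme assumed
    json={'document_ids': ['<document-uuid>']},
)
# Per the handler above: failed documents are re-queued, archived/completed
# ones are skipped with a log entry, and success returns 204.
assert resp.status_code == 204
```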

View File

@@ -0,0 +1,159 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.tag_fields import tag_fields
from libs.login import login_required
from models.model import Tag
from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError('Name must be between 1 to 50 characters.')
return name
class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(tag_fields)
def get(self):
tag_type = request.args.get('type', type=str)
keyword = request.args.get('keyword', default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
return tags, 200
@setup_required
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='Name must be between 1 to 50 characters.',
type=_validate_name)
parser.add_argument('type', type=str, location='json',
choices=Tag.TAG_TYPE_LIST,
nullable=True,
help='Invalid tag type.')
args = parser.parse_args()
tag = TagService.save_tags(args)
response = {
'id': tag.id,
'name': tag.name,
'type': tag.type,
'binding_count': 0
}
return response, 200
class TagUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, tag_id):
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='Name must be between 1 to 50 characters.',
type=_validate_name)
args = parser.parse_args()
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
response = {
'id': tag.id,
'name': tag.name,
'type': tag.type,
'binding_count': binding_count
}
return response, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, tag_id):
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
TagService.delete_tag(tag_id)
return 200
class TagBindingCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
help='Tag IDs is required.')
parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
help='Target ID is required.')
parser.add_argument('type', type=str, location='json',
choices=Tag.TAG_TYPE_LIST,
nullable=True,
help='Invalid tag type.')
args = parser.parse_args()
TagService.save_tag_binding(args)
return 200
class TagBindingDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('tag_id', type=str, nullable=False, required=True,
help='Tag ID is required.')
parser.add_argument('target_id', type=str, nullable=False, required=True,
help='Target ID is required.')
parser.add_argument('type', type=str, location='json',
choices=Tag.TAG_TYPE_LIST,
nullable=True,
help='Invalid tag type.')
args = parser.parse_args()
TagService.delete_tag_binding(args)
return 200
api.add_resource(TagListApi, '/tags')
api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
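
Taken together, the new routes support a flow like the following hypothetical session — the host, `/console/api` prefix, auth scheme, and the `'app'` tag type are assumptions:

```python
import requests

BASE = 'http://localhost:5001/console/api'             # hypothetical prefix
HEADERS = {'Authorization': 'Bearer <console-token>'}  # auth scheme assumed

# Create a tag (name validated server-side, type drawn from Tag.TAG_TYPE_LIST).
tag = requests.post(f'{BASE}/tags', headers=HEADERS,
                    json={'name': 'marketing', 'type': 'app'}).json()

# Bind it to a target, then remove the binding again.
requests.post(f'{BASE}/tag-bindings/create', headers=HEADERS,
              json={'tag_ids': [tag['id']], 'target_id': '<app-uuid>', 'type': 'app'})
requests.post(f'{BASE}/tag-bindings/remove', headers=HEADERS,
              json={'tag_id': tag['id'], 'target_id': '<app-uuid>', 'type': 'app'})
```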

View File

@@ -9,7 +9,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from libs.login import login_required
from models.account import Account
from models.account import Account, TenantAccountRole
from services.account_service import RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
@@ -43,7 +43,7 @@ class MemberInviteEmailApi(Resource):
invitee_emails = args['emails']
invitee_role = args['role']
interface_language = args['language']
if invitee_role not in ['admin', 'normal']:
if invitee_role not in [TenantAccountRole.ADMIN, TenantAccountRole.NORMAL]:
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
inviter = current_user

View File

@@ -11,6 +11,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import TenantAccountRole
from services.model_provider_service import ModelProviderService
@@ -94,7 +95,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
@@ -125,7 +126,7 @@ class ModelProviderModelApi(Resource):
@login_required
@account_initialization_required
def delete(self, provider: str):
if current_user.current_tenant.current_role not in ['admin', 'owner']:
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id

View File

@@ -26,8 +26,11 @@ class DatasetApi(DatasetApiResource):
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
provider = request.args.get('provider', default="vendor")
search = request.args.get('keyword', default=None, type=str)
tag_ids = request.args.getlist('tag_ids')
datasets, total = DatasetService.get_datasets(page, limit, provider,
tenant_id, current_user)
tenant_id, current_user, search, tag_ids)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(

View File

@@ -163,6 +163,7 @@ class BaseAgentRunner(AppRunner):
"""
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
)
tool_entity.load_variables(self.variables_pool)

View File

@@ -18,7 +18,7 @@ from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App, Conversation, Message
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -56,6 +56,14 @@ class AdvancedChatAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
@@ -98,7 +106,8 @@ class AdvancedChatAppRunner(AppRunner):
system_inputs={
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION: conversation.id,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks
)
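
Net effect of this hunk and the matching pipeline change below: chatflow runs now carry two extra system inputs alongside query and files. Conceptually (values illustrative):

```python
system_inputs = {
    'query': 'hello',
    'files': [],
    'conversation_id': 'c-123',  # new: SystemVariable.CONVERSATION_ID
    'user_id': 'u-456',          # new: session_id for web/API end users,
                                 # the account id otherwise
}
```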

View File

@@ -84,13 +84,19 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
if isinstance(self._user, EndUser):
user_id = self._user.session_id
else:
user_id = self._user.id
self._workflow = workflow
self._conversation = conversation
self._message = message
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION: conversation.id,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
}
self._task_state = AdvancedChatTaskState(

View File

@@ -23,20 +23,28 @@ class BaseAppGenerator:
value = user_inputs[variable]
if value:
if not isinstance(value, str):
if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string")
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
if '.' in value:
value = float(value)
else:
value = int(value)
if variable_config.type == VariableEntity.Type.SELECT:
options = variable_config.options if variable_config.options is not None else []
if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}")
else:
elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
if variable_config.max_length is not None:
max_length = variable_config.max_length
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
filtered_inputs[variable] = value.replace('\x00', '') if value else None
if value and isinstance(value, str):
filtered_inputs[variable] = value.replace('\x00', '')
else:
filtered_inputs[variable] = value if value else None
return filtered_inputs
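
The NUMBER branch above converts string form inputs before validation; the conversion rule in isolation:

```python
def coerce_number(value):
    # Mirrors the new VariableEntity.Type.NUMBER handling: strings containing
    # a decimal point become floats, all other strings become ints.
    if isinstance(value, str):
        return float(value) if '.' in value else int(value)
    return value

assert coerce_number('42') == 42
assert coerce_number('3.5') == 3.5
assert coerce_number(7) == 7
```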

View File

@@ -14,7 +14,7 @@ from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.model import App
from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -36,6 +36,14 @@ class WorkflowAppRunner:
app_config = application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
@@ -67,7 +75,8 @@ class WorkflowAppRunner:
else UserFrom.END_USER,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks
)

View File

@@ -71,9 +71,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
if isinstance(self._user, EndUser):
user_id = self._user.session_id
else:
user_id = self._user.id
self._workflow = workflow
self._workflow_system_variables = {
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.USER_ID: user_id
}
self._task_state = WorkflowTaskState()

View File

@@ -72,7 +72,7 @@ class AppGenerateEntity(BaseModel):
# app config
app_config: AppConfig
inputs: dict[str, str]
inputs: dict[str, Any]
files: list[FileVar] = []
user_id: str

View File

@@ -118,7 +118,8 @@ class MessageCycleManage:
:param event: event
:return:
"""
self._task_state.metadata['retriever_resources'] = event.retriever_resources
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata['retriever_resources'] = event.retriever_resources
def _get_response_metadata(self) -> dict:
"""

View File

@@ -1,10 +1,13 @@
import json
import re
from base64 import b64encode
from core.helper.code_executor.template_transformer import TemplateTransformer
PYTHON_RUNNER = """
import jinja2
from json import loads
from base64 import b64decode
template = jinja2.Template('''{{code}}''')
@@ -12,7 +15,8 @@ def main(**inputs):
return template.render(**inputs)
# execute main function, and return the result
output = main(**{{inputs}})
inputs = b64decode('{{inputs}}').decode('utf-8')
output = main(**loads(inputs))
result = f'''<<RESULT>>{output}<<RESULT>>'''
@@ -39,6 +43,7 @@ JINJA2_PRELOAD_TEMPLATE = """{% set fruits = ['Apple'] %}
JINJA2_PRELOAD = f"""
import jinja2
from base64 import b64decode
def _jinja2_preload_():
# prepare jinja2 environment, load template and render before to avoid sandbox issue
@@ -60,9 +65,11 @@ class Jinja2TemplateTransformer(TemplateTransformer):
:return:
"""
inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
# transform jinja2 template to python code
runner = PYTHON_RUNNER.replace('{{code}}', code)
runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4, ensure_ascii=False))
runner = runner.replace('{{inputs}}', inputs_str)
return runner, JINJA2_PRELOAD

View File

@@ -1,17 +1,22 @@
import json
import re
from base64 import b64encode
from core.helper.code_executor.template_transformer import TemplateTransformer
PYTHON_RUNNER = """# declare main function here
{{code}}
from json import loads, dumps
from base64 import b64decode
# execute main function, and return the result
# inputs is a dict, and it
output = main(**{{inputs}})
inputs = b64decode('{{inputs}}').decode('utf-8')
output = main(**json.loads(inputs))
# convert output to json and print
output = json.dumps(output, indent=4)
output = dumps(output, indent=4)
result = f'''<<RESULT>>
{output}
@@ -54,7 +59,7 @@ class PythonTemplateTransformer(TemplateTransformer):
"""
# transform inputs to json string
inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)
inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
# replace code and inputs
runner = PYTHON_RUNNER.replace('{{code}}', code)
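
Both transformers now ship inputs into the generated runner as base64-wrapped JSON rather than splicing `json.dumps` output directly into source code, so quotes, newlines, and braces in user input can no longer break or alter the generated script. The round-trip:

```python
import json
from base64 import b64decode, b64encode

inputs = {'query': "line1\nline2 with 'quotes' and {{braces}}", 'n': 42}

# Host side: encode once and embed as an opaque string literal in the runner.
encoded = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')

# Sandbox side (PYTHON_RUNNER): decode and parse before calling main(**inputs).
decoded = json.loads(b64decode(encoded).decode('utf-8'))
assert decoded == inputs
```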

View File

@@ -11,12 +11,13 @@ class ToolParameterCacheType(Enum):
class ToolParameterCache:
def __init__(self,
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType,
identity_id: str
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"
def get(self) -> Optional[dict]:
"""

View File

@@ -10,3 +10,6 @@
- cohere.command-text-v14
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
- mistral.mistral-large-2402-v1:0
- mistral.mixtral-8x7b-instruct-v0:1
- mistral.mistral-7b-instruct-v0:2

View File

@@ -449,6 +449,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
human_prompt_prefix = "\n[INST]"
human_prompt_postfix = "[\\INST]\n"
ai_prompt = ""
elif model_prefix == "mistral":
human_prompt_prefix = "<s>[INST]"
human_prompt_postfix = "[\\INST]\n"
ai_prompt = "\n\nAssistant:"
elif model_prefix == "amazon":
human_prompt_prefix = "\n\nUser:"
@@ -519,6 +524,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
if model_parameters.get("countPenalty"):
payload["countPenalty"] = {model_parameters.get("countPenalty")}
elif model_prefix == "mistral":
payload["temperature"] = model_parameters.get("temperature")
payload["top_p"] = model_parameters.get("top_p")
payload["max_tokens"] = model_parameters.get("max_tokens")
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
payload["stop"] = stop[:10] if stop else []
elif model_prefix == "anthropic":
payload = { **model_parameters }
@@ -648,6 +660,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
output = response_body.get("generation").strip('\n')
prompt_tokens = response_body.get("prompt_token_count")
completion_tokens = response_body.get("generation_token_count")
elif model_prefix == "mistral":
output = response_body.get("outputs")[0].get("text")
prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count')
completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count')
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@@ -731,6 +748,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
content_delta = payload.get("text")
finish_reason = payload.get("finish_reason")
elif model_prefix == "mistral":
content_delta = payload.get('outputs')[0].get("text")
finish_reason = payload.get('outputs')[0].get("stop_reason")
elif model_prefix == "meta":
content_delta = payload.get("generation").strip('\n')
finish_reason = payload.get("stop_reason")
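
For the new `mistral.*` model IDs the invoke payload is assembled from the hunks above; an illustrative instance (the multi-turn prompt assembly via `_convert_messages_to_prompt` is abbreviated here):

```python
payload = {
    "temperature": 0.7,
    "top_p": 1.0,
    "max_tokens": 512,
    # Wrapped with the mistral prefixes defined above: "<s>[INST]" / "[\\INST]\n".
    "prompt": "<s>[INST]Hello, who are you?[\\INST]\n",
    "stop": [],  # truncated to at most 10 stop sequences
}
```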

View File

@@ -0,0 +1,39 @@
model: mistral.mistral-7b-instruct-v0:2
label:
en_US: Mistral 7B Instruct
model_type: llm
model_properties:
mode: completion
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
required: false
default: 0.5
- name: top_p
use_template: top_p
required: false
default: 0.9
- name: top_k
use_template: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 50
max: 200
- name: max_tokens
use_template: max_tokens
required: true
default: 512
min: 1
max: 8192
pricing:
input: '0.00015'
output: '0.0002'
unit: '0.00001'
currency: USD

View File

@@ -0,0 +1,27 @@
model: mistral.mistral-large-2402-v1:0
label:
en_US: Mistral Large
model_type: llm
model_properties:
mode: completion
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
required: false
default: 0.7
- name: top_p
use_template: top_p
required: false
default: 1
- name: max_tokens
use_template: max_tokens
required: true
default: 512
min: 1
max: 4096
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD

View File

@@ -0,0 +1,39 @@
model: mistral.mixtral-8x7b-instruct-v0:1
label:
en_US: Mixtral 8X7B Instruct
model_type: llm
model_properties:
mode: completion
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
required: false
default: 0.5
- name: top_p
use_template: top_p
required: false
default: 0.9
- name: top_k
use_template: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 50
max: 200
- name: max_tokens
use_template: max_tokens
required: true
default: 512
min: 1
max: 8192
pricing:
input: '0.00045'
output: '0.0007'
unit: '0.00001'
currency: USD

View File

@@ -19,7 +19,7 @@ class GroqProvider(ModelProvider):
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='llama2-70b-4096',
model='llama3-8b-8192',
credentials=credentials
)
except CredentialsValidateFailedError as ex:

View File

@@ -0,0 +1,25 @@
model: llama3-70b-8192
label:
zh_Hans: Llama-3-70B-8192
en_US: Llama-3-70B-8192
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.05'
output: '0.1'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,25 @@
model: llama3-8b-8192
label:
zh_Hans: Llama-3-8B-8192
en_US: Llama-3-8B-8192
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.59'
output: '0.79'
unit: '0.000001'
currency: USD

View File

@@ -47,6 +47,20 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
# initialize client
client = Client(
base_url=credentials['server_url']
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])
if not isinstance(xinference_client, RESTfulAudioModelHandle):
raise InvokeBadRequestError(
'please check model type, the model you want to invoke is not a audio model')
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, 'rb') as audio_file:
@@ -110,17 +124,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
# initialize client
client = Client(
base_url=credentials['server_url']
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])
if not isinstance(xinference_client, RESTfulAudioModelHandle):
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model')
response = xinference_client.transcriptions(
handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
response = handle.transcriptions(
audio=file,
language = language,
prompt = prompt,

View File

@@ -31,7 +31,10 @@ class AdvancedPromptTransform(PromptTransform):
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
inputs = {key: str(value) for key, value in inputs.items()}
prompt_messages = []
model_mode = ModelMode.value_of(model_config.mode)
@@ -51,6 +54,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt_template=prompt_template,
inputs=inputs,
query=query,
query_prompt_template=query_prompt_template,
files=files,
context=context,
memory_config=memory_config,
@@ -119,7 +123,8 @@ class AdvancedPromptTransform(PromptTransform):
context: Optional[str],
memory_config: Optional[MemoryConfig],
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
model_config: ModelConfigWithCredentialsEntity,
query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
"""
Get chat model prompt messages.
"""
@@ -146,6 +151,20 @@ class AdvancedPromptTransform(PromptTransform):
elif prompt_item.role == PromptMessageRole.ASSISTANT:
prompt_messages.append(AssistantPromptMessage(content=prompt))
if query and query_prompt_template:
prompt_template = PromptTemplateParser(
template=query_prompt_template,
with_variable_tmpl=self.with_variable_tmpl
)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt_inputs['#sys.query#'] = query
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
query = prompt_template.format(
prompt_inputs
)
if memory and memory_config:
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
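
An illustrative `query_prompt_template` for the new chatflow option; `{{#sys.query#}}` corresponds to the `'#sys.query#'` key the code above injects, while the surrounding wording is hypothetical:

```python
query_prompt_template = (
    'Answer using the retrieved context when it is relevant.\n'
    'Question: {{#sys.query#}}'
)
```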

View File

@@ -40,3 +40,4 @@ class MemoryConfig(BaseModel):
role_prefix: Optional[RolePrefix] = None
window: WindowConfig
query_prompt_template: Optional[str] = None

View File

@@ -55,6 +55,8 @@ class SimplePromptTransform(PromptTransform):
memory: Optional[TokenBufferMemory],
model_config: ModelConfigWithCredentialsEntity) -> \
tuple[list[PromptMessage], Optional[list[str]]]:
inputs = {key: str(value) for key, value in inputs.items()}
model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages(

View File

@@ -110,19 +110,37 @@ class MilvusVector(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
from pymilvus import utility
if utility.has_collection(self._collection_name, using=alias):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, doc_ids: list[str]) -> None:
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {doc_ids}',
output_fields=["id"])
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
from pymilvus import utility
if utility.has_collection(self._collection_name, using=alias):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] in {doc_ids}',
output_fields=["id"])
if result:
ids = [item["id"] for item in result]
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete(self) -> None:
alias = uuid4().hex
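
Both delete paths now guard on collection existence over a fresh, uniquely-aliased connection, so deleting from a dataset whose collection was never created becomes a no-op instead of an error. A sketch of the guard, assuming a reachable Milvus at the URI shown:

```python
from pymilvus import connections, utility

alias = 'existence-check'  # the code above uses uuid4().hex per call
connections.connect(alias=alias, uri='http://localhost:19530')
if utility.has_collection('my_collection', using=alias):
    print('collection exists; safe to query and delete by ids')
else:
    print('collection missing; skip the delete entirely')
```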

View File

@@ -50,7 +50,8 @@ class QdrantConfig(BaseModel):
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout
'timeout': self.timeout,
'verify': self.endpoint.startswith('https')
}
@@ -217,29 +218,38 @@ class QdrantVector(BaseVector):
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
try:
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete(self):
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
filter = models.Filter(
must=[
@@ -257,29 +267,40 @@ class QdrantVector(BaseVector):
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
for node_id in ids:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
try:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
)
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
return
# Some other error occurred, so re-raise the exception
else:
raise e
def text_exists(self, id: str) -> bool:
all_collection_name = []
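
The one-line `verify` addition is what fixes the HTTP-Qdrant issue from the commit list (`78988ed60e`): TLS verification is now requested only for https endpoints. In isolation:

```python
def qdrant_params(endpoint: str, api_key: str | None, timeout: float) -> dict:
    return {
        'url': endpoint,
        'api_key': api_key,
        'timeout': timeout,
        # Previously verification stayed on unconditionally, which broke
        # plain-HTTP deployments; now it simply tracks the URL scheme.
        'verify': endpoint.startswith('https'),
    }

assert qdrant_params('http://qdrant:6333', None, 20)['verify'] is False
assert qdrant_params('https://qdrant.example.com', 'key', 20)['verify'] is True
```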

View File

@@ -121,18 +121,20 @@ class WeaviateVector(BaseVector):
return ids
def delete_by_metadata_field(self, key: str, value: str):
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
def delete(self):
# check whether the index already exists
@@ -163,11 +165,14 @@ class WeaviateVector(BaseVector):
return True
def delete_by_ids(self, ids: list[str]) -> None:
for uuid in ids:
self._client.data_object.delete(
class_name=self._collection_name,
uuid=uuid,
)
# check whether the index already exists
schema = self._default_schema(self._collection_name)
if self._client.schema.contains(schema):
for uuid in ids:
self._client.data_object.delete(
class_name=self._collection_name,
uuid=uuid,
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""

View File

@@ -44,27 +44,31 @@ class BingSearchTool(BuiltinTool):
results = []
if search_results:
for result in search_results:
url = f': {result["url"]}' if "url" in result else ""
results.append(self.create_text_message(
text=f'{result["name"]}: {result["url"]}'
text=f'{result["name"]}{url}'
))
if entities:
for entity in entities:
url = f': {entity["url"]}' if "url" in entity else ""
results.append(self.create_text_message(
text=f'{entity["name"]}: {entity["url"]}'
text=f'{entity.get("name", "")}{url}'
))
if news:
for news_item in news:
url = f': {news_item["url"]}' if "url" in news_item else ""
results.append(self.create_text_message(
text=f'{news_item["name"]}: {news_item["url"]}'
text=f'{news_item.get("name", "")}{url}'
))
if related_searches:
for related in related_searches:
url = f': {related["displayText"]}' if "displayText" in related else ""
results.append(self.create_text_message(
text=f'{related["displayText"]}: {related["webSearchUrl"]}'
text=f'{related.get("displayText", "")}{url}'
))
return results
@@ -73,7 +77,7 @@ class BingSearchTool(BuiltinTool):
text = ''
if search_results:
for i, result in enumerate(search_results):
text += f'{i+1}: {result["name"]} - {result["snippet"]}\n'
text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
if computation and 'expression' in computation and 'value' in computation:
text += '\nComputation:\n'
@@ -82,17 +86,20 @@ class BingSearchTool(BuiltinTool):
if entities:
text += '\nEntities:\n'
for entity in entities:
text += f'{entity["name"]} - {entity["url"]}\n'
url = f'- {entity["url"]}' if "url" in entity else ""
text += f'{entity.get("name", "")}{url}\n'
if news:
text += '\nNews:\n'
for news_item in news:
text += f'{news_item["name"]} - {news_item["url"]}\n'
url = f'- {news_item["url"]}' if "url" in news_item else ""
text += f'{news_item.get("name", "")}{url}\n'
if related_searches:
text += '\n\nRelated Searches:\n'
for related in related_searches:
text += f'{related["displayText"]} - {related["webSearchUrl"]}\n'
url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
text += f'{related.get("displayText", "")}{url}\n'
return self.create_text_message(text=self.summary(user_id=user_id, content=text))

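The common thread in these edits is defensive access: every name, url, and displayText read now goes through .get() or a membership test, so a partial Bing response can no longer raise KeyError. A self-contained sketch of the pattern with an invented payload:

# sample payload; the second entry deliberately omits "url"
search_results = [
    {"name": "Dify", "url": "https://dify.ai"},
    {"name": "No link here"},
]

lines = []
for result in search_results:
    url = f': {result["url"]}' if "url" in result else ""
    lines.append(f'{result.get("name", "")}{url}')

print("\n".join(lines))
# Dify: https://dify.ai
# No link here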

@@ -7,10 +7,10 @@ identity:
pt_BR: Interpretador de Código
description:
human:
en_US: Run code and get the result back, when you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code.
zh_Hans: 运行一段代码并返回结果,当您使用较低质量的模型时,请确保有一些提示帮助LLM理解如何编写代码。
pt_BR: Execute um trecho de código e obtenha o resultado de volta, quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
llm: A tool for running code and getting the result back, but only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty.
en_US: Run code and get the result back. When you're using a lower-quality model, please make sure there are some tips to help the LLM understand how to write the code.
zh_Hans: 运行一段代码并返回结果,当您使用较低质量的模型时,请确保有一些提示帮助LLM理解如何编写代码。
pt_BR: Execute um trecho de código e obtenha o resultado de volta. Quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
llm: A tool for running code and getting the result back. Only native packages are allowed; network/IO operations are disabled, and you must use print() or console.log() to output the result, or the result will be empty.
parameters:
- name: language
type: string


@@ -0,0 +1,21 @@
<?xml version="1.0" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 20010904//EN"
"http://www.w3.org/TR/2001/REC-SVG-20010904/DTD/svg10.dtd">
<svg version="1.0" xmlns="http://www.w3.org/2000/svg"
width="128.000000pt" height="128.000000pt" viewBox="0 0 128.000000 128.000000"
preserveAspectRatio="xMidYMid meet">
<g transform="translate(0.000000,128.000000) scale(0.100000,-0.100000)"
fill="#000000" stroke="none">
<path d="M0 975 l0 -305 33 1 c54 0 336 35 343 41 3 4 0 57 -7 118 -10 85 -17
113 -29 120 -47 25 -45 104 2 133 13 8 118 26 246 41 208 26 225 26 248 11 14
-9 30 -27 36 -41 10 -22 8 -33 -10 -68 l-23 -42 40 -316 40 -315 30 -31 c17
-17 31 -38 31 -47 0 -25 -27 -72 -46 -79 -35 -13 -450 -59 -476 -53 -52 13
-70 85 -32 127 10 13 10 33 -1 120 -8 58 -15 111 -15 118 0 16 -31 16 -237 -5
l-173 -17 0 -243 0 -243 640 0 640 0 0 640 0 640 -640 0 -640 0 0 -305z"/>
<path d="M578 977 c-128 -16 -168 -24 -168 -35 0 -10 8 -12 28 -8 15 3 90 12
167 21 167 18 188 23 180 35 -7 12 -1 12 -207 -13z"/>
<path d="M660 326 c-100 -13 -163 -25 -160 -31 3 -5 14 -9 25 -8 104 11 305
35 323 39 12 2 22 9 22 14 0 13 -14 12 -210 -14z"/>
</g>
</svg>



@@ -0,0 +1,23 @@
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.judge0ce.tools.submitCodeExecutionTask import SubmitCodeExecutionTaskTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class Judge0CEProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
SubmitCodeExecutionTaskTool().fork_tool_runtime(
meta={
"credentials": credentials,
}
).invoke(
user_id='',
tool_parameters={
"source_code": "print('hello world')",
"language_id": 71,
},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))


@@ -0,0 +1,29 @@
identity:
author: Richards Tu
name: judge0ce
label:
en_US: Judge0 CE
zh_Hans: Judge0 CE
pt_BR: Judge0 CE
description:
en_US: Judge0 CE is an open-source code execution system. It supports various languages, including C, C++, Java, Python, Ruby, etc.
zh_Hans: Judge0 CE 是一个开源的代码执行系统。支持多种语言,包括 C、C++、Java、Python、Ruby 等。
pt_BR: Judge0 CE é um sistema de execução de código de código aberto. Suporta várias linguagens, incluindo C, C++, Java, Python, Ruby, etc.
icon: icon.svg
credentials_for_provider:
X-RapidAPI-Key:
type: secret-input
required: true
label:
en_US: RapidAPI Key
zh_Hans: RapidAPI Key
pt_BR: RapidAPI Key
help:
en_US: RapidAPI Key is required to access the Judge0 CE API.
zh_Hans: RapidAPI Key 是访问 Judge0 CE API 所必需的。
pt_BR: RapidAPI Key é necessário para acessar a API do Judge0 CE.
placeholder:
en_US: Enter your RapidAPI Key
zh_Hans: 输入你的 RapidAPI Key
pt_BR: Insira sua RapidAPI Key
url: https://rapidapi.com/judge0-official/api/judge0-ce


@@ -0,0 +1,37 @@
from typing import Any, Union
import requests
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class GetExecutionResultTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key = self.runtime.credentials['X-RapidAPI-Key']
url = f"https://judge0-ce.p.rapidapi.com/submissions/{tool_parameters['token']}"
headers = {
"X-RapidAPI-Key": api_key
}
response = requests.get(url, headers=headers)
if response.status_code == 200:
result = response.json()
return self.create_text_message(text=f"Submission details:\n"
f"stdout: {result.get('stdout', '')}\n"
f"stderr: {result.get('stderr', '')}\n"
f"compile_output: {result.get('compile_output', '')}\n"
f"message: {result.get('message', '')}\n"
f"status: {result['status']['description']}\n"
f"time: {result.get('time', '')} seconds\n"
f"memory: {result.get('memory', '')} bytes")
else:
return self.create_text_message(text=f"Error retrieving submission details: {response.text}")


@@ -0,0 +1,23 @@
identity:
name: getExecutionResult
author: Richards Tu
label:
en_US: Get Execution Result
zh_Hans: 获取执行结果
description:
human:
en_US: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask.
zh_Hans: 一个用于通过 submitCodeExecutionTask 工具提供的特定令牌来检索代码提交详细信息的工具。
llm: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask.
parameters:
- name: token
type: string
required: true
label:
en_US: Token
zh_Hans: 令牌
human_description:
en_US: The submission's unique token.
zh_Hans: 提交的唯一令牌。
llm_description: The submission's unique token. MUST get from submitCodeExecution.
form: llm


@@ -0,0 +1,49 @@
import json
from typing import Any, Union
from httpx import post
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class SubmitCodeExecutionTaskTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key = self.runtime.credentials['X-RapidAPI-Key']
source_code = tool_parameters['source_code']
language_id = tool_parameters['language_id']
stdin = tool_parameters.get('stdin', '')
expected_output = tool_parameters.get('expected_output', '')
additional_files = tool_parameters.get('additional_files', '')
url = "https://judge0-ce.p.rapidapi.com/submissions"
querystring = {"base64_encoded": "false", "fields": "*"}
payload = {
"language_id": language_id,
"source_code": source_code,
"stdin": stdin,
"expected_output": expected_output,
"additional_files": additional_files,
}
headers = {
"content-type": "application/json",
"Content-Type": "application/json",
"X-RapidAPI-Key": api_key,
"X-RapidAPI-Host": "judge0-ce.p.rapidapi.com"
}
response = post(url, data=json.dumps(payload), headers=headers, params=querystring)
if response.status_code != 201:
raise Exception(response.text)
token = response.json()['token']
return self.create_text_message(text=token)

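Taken together, the two new tools implement a submit-then-fetch protocol: submitCodeExecutionTask POSTs a submission and returns a token, and getExecutionResult GETs the submission by that token. A hedged end-to-end sketch with plain requests; the key is a placeholder, and since Judge0 executes asynchronously the GET may need to be repeated until the status is terminal:

import requests

API_KEY = "your-rapidapi-key"  # placeholder
headers = {
    "Content-Type": "application/json",
    "X-RapidAPI-Key": API_KEY,
    "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com",
}

# step 1: submit the task (language_id 71 is Python, as in the provider's validator)
resp = requests.post(
    "https://judge0-ce.p.rapidapi.com/submissions",
    params={"base64_encoded": "false", "fields": "*"},
    json={"language_id": 71, "source_code": "print('hello world')"},
    headers=headers,
)
resp.raise_for_status()
token = resp.json()["token"]

# step 2: fetch the result by token (poll until status is no longer queued/processing)
result = requests.get(
    f"https://judge0-ce.p.rapidapi.com/submissions/{token}",
    headers=headers,
).json()
print(result.get("stdout"), result.get("status", {}).get("description"))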

@@ -0,0 +1,67 @@
identity:
name: submitCodeExecutionTask
author: Richards Tu
label:
en_US: Submit Code Execution Task
zh_Hans: 提交代码执行任务
description:
human:
en_US: A tool for submitting code execution task to Judge0 CE.
zh_Hans: 一个用于向 Judge0 CE 提交代码执行任务的工具。
llm: A tool for submitting a new code execution task to Judge0 CE. It takes in the source code, language ID, standard input (optional), expected output (optional), and additional files (optional) as parameters; and returns a unique token representing the submission.
parameters:
- name: source_code
type: string
required: true
label:
en_US: Source Code
zh_Hans: 源代码
human_description:
en_US: The source code to be executed.
zh_Hans: 要执行的源代码。
llm_description: The source code to be executed.
form: llm
- name: language_id
type: number
required: true
label:
en_US: Language ID
zh_Hans: 语言 ID
human_description:
en_US: The ID of the language in which the source code is written.
zh_Hans: 源代码所使用的语言的 ID。
llm_description: The ID of the language in which the source code is written. For example, 50 for C++, 71 for Python, etc.
form: llm
- name: stdin
type: string
required: false
label:
en_US: Standard Input
zh_Hans: 标准输入
human_description:
en_US: The standard input to be provided to the program.
zh_Hans: 提供给程序的标准输入。
llm_description: The standard input to be provided to the program. Optional.
form: llm
- name: expected_output
type: string
required: false
label:
en_US: Expected Output
zh_Hans: 期望输出
human_description:
en_US: The expected output of the program. Used for comparison in some scenarios.
zh_Hans: 程序的期望输出。在某些场景下用于比较。
llm_description: The expected output of the program. Used for comparison in some scenarios. Optional.
form: llm
- name: additional_files
type: string
required: false
label:
en_US: Additional Files
zh_Hans: 附加文件
human_description:
en_US: Base64 encoded additional files for the submission.
zh_Hans: 提交的 Base64 编码的附加文件。
llm_description: Base64 encoded additional files for the submission. Optional.
form: llm


@@ -222,7 +222,7 @@ class ToolManager:
return parameter_value
@classmethod
def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
"""
get the agent tool runtime
"""
@@ -245,6 +245,7 @@ class ToolManager:
tool_runtime=tool_entity,
provider_name=agent_tool.provider_id,
provider_type=agent_tool.provider_type,
identity_id=f'AGENT.{app_id}'
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
@@ -252,7 +253,7 @@ class ToolManager:
return tool_entity
@classmethod
def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
"""
get the workflow tool runtime
"""
@@ -277,6 +278,7 @@ class ToolManager:
tool_runtime=tool_entity,
provider_name=workflow_tool.provider_id,
provider_type=workflow_tool.provider_type,
identity_id=f'WORKFLOW.{app_id}.{node_id}'
)
if runtime_parameters:

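The new identity_id argument threads a cache scope through both call sites: agent tools are keyed per app, workflow tools per app and node, so two apps sharing one tool no longer share one parameter cache entry. A trivial sketch of the key shapes introduced here:

def agent_identity_id(app_id: str) -> str:
    return f'AGENT.{app_id}'

def workflow_identity_id(app_id: str, node_id: str) -> str:
    return f'WORKFLOW.{app_id}.{node_id}'

assert agent_identity_id('app-1') != agent_identity_id('app-2')
assert workflow_identity_id('app-1', 'node-9') == 'WORKFLOW.app-1.node-9'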

@@ -113,12 +113,13 @@ class ToolParameterConfigurationManager(BaseModel):
tool_runtime: Tool
provider_name: str
provider_type: str
identity_id: str
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return {key: value for key, value in parameters.items()}
return deepcopy(parameters)
def _merge_parameters(self) -> list[ToolParameter]:
"""
@@ -176,6 +177,8 @@ class ToolParameterConfigurationManager(BaseModel):
# override parameters
current_parameters = self._merge_parameters()
parameters = self._deep_copy(parameters)
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
@@ -194,7 +197,8 @@ class ToolParameterConfigurationManager(BaseModel):
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id
)
cached_parameters = cache.get()
if cached_parameters:
@@ -223,7 +227,8 @@ class ToolParameterConfigurationManager(BaseModel):
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id
)
cache.delete()

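The _deep_copy fix matters because a dict comprehension copies only the top level: nested parameter dicts stay shared with the source, so masking a cached copy could silently mutate the original. A standalone demonstration of the difference:

from copy import deepcopy

parameters = {"api_key": "secret", "options": {"depth": 2}}

shallow = {key: value for key, value in parameters.items()}  # the old behaviour
deep = deepcopy(parameters)                                  # the new behaviour

shallow["options"]["depth"] = 99  # mutates the original through the shared nested dict
print(parameters["options"]["depth"])  # 99 - the shallow copy leaked the change

deep["options"]["depth"] = 1
print(parameters["options"]["depth"])  # still 99 - the deep copy is isolated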

@@ -43,7 +43,8 @@ class SystemVariable(Enum):
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION = 'conversation'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
@classmethod
def value_of(cls, value: str) -> 'SystemVariable':

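With this rename, chatflow/workflow code addresses the conversation as sys.conversation_id and gains sys.user_id. A standalone sketch of the enum and of the selector shape the LLM node now builds from it:

from enum import Enum

class SystemVariable(Enum):
    QUERY = 'query'
    FILES = 'files'
    CONVERSATION_ID = 'conversation_id'
    USER_ID = 'user_id'

# e.g. the LLM node's conversation lookup becomes:
selector = ['sys', SystemVariable.CONVERSATION_ID.value]
print(selector)  # ['sys', 'conversation_id']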

@@ -74,6 +74,7 @@ class LLMNode(BaseNode):
node_data=node_data,
query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs,
files=files,
context=context,
@@ -209,6 +210,17 @@ class LLMNode(BaseNode):
inputs[variable_selector.variable] = variable_value
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
.extract_variable_selectors())
for variable_selector in query_variable_selectors:
variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
if variable_value is None:
raise ValueError(f'Variable {variable_selector.variable} not found')
inputs[variable_selector.variable] = variable_value
return inputs
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
@@ -302,7 +314,8 @@ class LLMNode(BaseNode):
return None
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config
:param node_data_model: node data model
@@ -385,7 +398,7 @@ class LLMNode(BaseNode):
return None
# get conversation id
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION_ID.value])
if conversation_id is None:
return None
@@ -407,6 +420,7 @@ class LLMNode(BaseNode):
def _fetch_prompt_messages(self, node_data: LLMNodeData,
query: Optional[str],
query_prompt_template: Optional[str],
inputs: dict[str, str],
files: list[FileVar],
context: Optional[str],
@@ -417,6 +431,7 @@ class LLMNode(BaseNode):
Fetch prompt messages
:param node_data: node data
:param query: query
:param query_prompt_template: query prompt template
:param inputs: inputs
:param files: files
:param context: context
@@ -433,7 +448,8 @@ class LLMNode(BaseNode):
context=context,
memory_config=node_data.memory,
memory=memory,
model_config=model_config
model_config=model_config,
query_prompt_template=query_prompt_template,
)
stop = model_config.stop
@@ -539,12 +555,22 @@ class LLMNode(BaseNode):
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
.extract_variable_selectors())
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if node_data.context.enabled:
variable_mapping['#context#'] = node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
return variable_mapping
@classmethod

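The query_prompt_template support reuses the selector-extraction step shown above in both the input-fetching code and the variable-mapping classmethod. A rough standalone approximation of that step; the regex and the {{#node_id.variable#}} syntax here are illustrative rather than Dify's exact parser:

import re

def extract_variable_selectors(template: str) -> list[list[str]]:
    # '{{#sys.query#}}' -> ['sys', 'query']
    return [match.split('.') for match in re.findall(r'{{#([\w.]+)#}}', template)]

template = 'Context: {{#sys.query#}} User: {{#start.user_name#}}'
print(extract_variable_selectors(template))
# [['sys', 'query'], ['start', 'user_name']]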

@@ -1,8 +1,6 @@
from typing import cast
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
@@ -19,17 +17,10 @@ class StartNode(BaseNode):
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
variables = node_data.variables
# Get cleaned inputs
cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs)
cleaned_inputs = variable_pool.user_inputs
for var in variable_pool.system_variables:
if var == SystemVariable.CONVERSATION:
continue
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
return NodeRunResult(
@@ -38,42 +29,6 @@ class StartNode(BaseNode):
outputs=cleaned_inputs
)
def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
if user_inputs is None:
user_inputs = {}
filtered_inputs = {}
for variable_config in variables:
variable = variable_config.variable
if variable not in user_inputs or not user_inputs[variable]:
if variable_config.required:
raise ValueError(f"Input form variable {variable} is required")
else:
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
continue
value = user_inputs[variable]
if value:
if not isinstance(value, str):
raise ValueError(f"{variable} in input form must be a string")
if variable_config.type == VariableEntity.Type.SELECT:
options = variable_config.options if variable_config.options is not None else []
if value not in options:
raise ValueError(f"{variable} in input form must be one of the following: {options}")
else:
if variable_config.max_length is not None:
max_length = variable_config.max_length
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
filtered_inputs[variable] = value.replace('\x00', '') if value else None
return filtered_inputs
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""


@@ -39,7 +39,8 @@ class ToolNode(BaseNode):
parameters = self._generate_parameters(variable_pool, node_data)
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,


@@ -22,5 +22,6 @@ def handle(sender, **kwargs):
tool_runtime=tool_runtime,
provider_name=tool_entity.provider_name,
provider_type=tool_entity.provider_type,
identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
)
manager.delete_tool_parameters_cache()


@@ -6,6 +6,7 @@ from datetime import datetime, timedelta, timezone
from typing import Union
import boto3
import oss2 as aliyun_s3
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
from botocore.client import Config
from botocore.exceptions import ClientError
@@ -42,7 +43,14 @@ class Storage:
)
self.client = BlobServiceClient(account_url=app.config.get('AZURE_BLOB_ACCOUNT_URL'),
credential=sas_token)
elif self.storage_type == 'aliyun-oss':
self.bucket_name = app.config.get('ALIYUN_OSS_BUCKET_NAME')
self.client = aliyun_s3.Bucket(
aliyun_s3.Auth(app.config.get('ALIYUN_OSS_ACCESS_KEY'), app.config.get('ALIYUN_OSS_SECRET_KEY')),
app.config.get('ALIYUN_OSS_ENDPOINT'),
self.bucket_name,
connect_timeout=30
)
else:
self.folder = app.config.get('STORAGE_LOCAL_PATH')
if not os.path.isabs(self.folder):
@@ -54,6 +62,8 @@ class Storage:
elif self.storage_type == 'azure-blob':
blob_container = self.client.get_container_client(container=self.bucket_name)
blob_container.upload_blob(filename, data)
elif self.storage_type == 'aliyun-oss':
self.client.put_object(filename, data)
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
@@ -86,6 +96,9 @@ class Storage:
blob = self.client.get_container_client(container=self.bucket_name)
blob = blob.get_blob_client(blob=filename)
data = blob.download_blob().readall()
elif self.storage_type == 'aliyun-oss':
with closing(self.client.get_object(filename)) as obj:
data = obj.read()
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
@@ -118,6 +131,10 @@ class Storage:
with closing(blob.download_blob()) as blob_stream:
while chunk := blob_stream.readall(4096):
yield chunk
elif self.storage_type == 'aliyun-oss':
with closing(self.client.get_object(filename)) as obj:
while chunk := obj.read(4096):
yield chunk
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
@@ -142,6 +159,8 @@ class Storage:
with open(target_filepath, "wb") as my_blob:
blob_data = blob.download_blob()
blob_data.readinto(my_blob)
elif self.storage_type == 'aliyun-oss':
self.client.get_object_to_file(filename, target_filepath)
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
@@ -164,6 +183,8 @@ class Storage:
elif self.storage_type == 'azure-blob':
blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
return blob.exists()
elif self.storage_type == 'aliyun-oss':
return self.client.object_exists(filename)
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
@@ -178,6 +199,8 @@ class Storage:
elif self.storage_type == 'azure-blob':
blob_container = self.client.get_container_client(container=self.bucket_name)
blob_container.delete_blob(filename)
elif self.storage_type == 'aliyun-oss':
self.client.delete_object(filename)
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename

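Enabling the new backend amounts to setting STORAGE_TYPE=aliyun-oss plus the ALIYUN_OSS_* variables read above. A minimal sketch of the same oss2 calls the storage class dispatches to, assuming those environment variables are set (the endpoint comment and object key are illustrative):

import os
import oss2

auth = oss2.Auth(os.environ['ALIYUN_OSS_ACCESS_KEY'], os.environ['ALIYUN_OSS_SECRET_KEY'])
bucket = oss2.Bucket(
    auth,
    os.environ['ALIYUN_OSS_ENDPOINT'],    # e.g. 'https://oss-cn-hangzhou.aliyuncs.com'
    os.environ['ALIYUN_OSS_BUCKET_NAME'],
    connect_timeout=30,
)

bucket.put_object('hello.txt', b'hello')      # save
print(bucket.get_object('hello.txt').read())  # load -> b'hello'
print(bucket.object_exists('hello.txt'))      # exists -> True
bucket.delete_object('hello.txt')             # delete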

@@ -62,6 +62,12 @@ model_config_partial_fields = {
'pre_prompt': fields.String,
}
tag_fields = {
'id': fields.String,
'name': fields.String,
'type': fields.String
}
app_partial_fields = {
'id': fields.String,
'name': fields.String,
@@ -70,9 +76,11 @@ app_partial_fields = {
'icon': fields.String,
'icon_background': fields.String,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True),
'created_at': TimestampField
'created_at': TimestampField,
'tags': fields.List(fields.Nested(tag_fields))
}
app_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),


@@ -27,6 +27,11 @@ dataset_retrieval_model_fields = {
'score_threshold': fields.Float
}
tag_fields = {
'id': fields.String,
'name': fields.String,
'type': fields.String
}
dataset_detail_fields = {
'id': fields.String,
@@ -46,7 +51,8 @@ dataset_detail_fields = {
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean,
'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields)
'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields),
'tags': fields.List(fields.Nested(tag_fields))
}
dataset_query_detail_fields = {

api/fields/tag_fields.py Normal file

@@ -0,0 +1,8 @@
from flask_restful import fields
tag_fields = {
'id': fields.String,
'name': fields.String,
'type': fields.String,
'binding_count': fields.String
}


@@ -0,0 +1,62 @@
"""add-tags-and-binding-table
Revision ID: 3c7cac9521c6
Revises: c3311b089690
Create Date: 2024-04-11 06:17:34.278594
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '3c7cac9521c6'
down_revision = 'c3311b089690'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tag_bindings',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', postgresql.UUID(), nullable=True),
sa.Column('tag_id', postgresql.UUID(), nullable=True),
sa.Column('target_id', postgresql.UUID(), nullable=True),
sa.Column('created_by', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tag_binding_pkey')
)
with op.batch_alter_table('tag_bindings', schema=None) as batch_op:
batch_op.create_index('tag_bind_tag_id_idx', ['tag_id'], unique=False)
batch_op.create_index('tag_bind_target_id_idx', ['target_id'], unique=False)
op.create_table('tags',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', postgresql.UUID(), nullable=True),
sa.Column('type', sa.String(length=16), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('created_by', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tag_pkey')
)
with op.batch_alter_table('tags', schema=None) as batch_op:
batch_op.create_index('tag_name_idx', ['name'], unique=False)
batch_op.create_index('tag_type_idx', ['type'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tags', schema=None) as batch_op:
batch_op.drop_index('tag_type_idx')
batch_op.drop_index('tag_name_idx')
op.drop_table('tags')
with op.batch_alter_table('tag_bindings', schema=None) as batch_op:
batch_op.drop_index('tag_bind_target_id_idx')
batch_op.drop_index('tag_bind_tag_id_idx')
op.drop_table('tag_bindings')
# ### end Alembic commands ###


@@ -100,10 +100,11 @@ class Account(UserMixin, db.Model):
return db.session.query(ai).filter(
ai.account_id == self.id
).all()
# check current_user.current_tenant.current_role in ['admin', 'owner']
@property
def is_admin_or_owner(self):
return self._current_tenant.current_role in ['admin', 'owner']
return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
class TenantStatus(str, enum.Enum):
@@ -111,6 +112,16 @@ class TenantStatus(str, enum.Enum):
ARCHIVE = 'archive'
class TenantAccountRole(str, enum.Enum):
OWNER = 'owner'
ADMIN = 'admin'
NORMAL = 'normal'
@staticmethod
def is_privileged_role(role: str) -> bool:
return role and role in {TenantAccountRole.ADMIN, TenantAccountRole.OWNER}
class Tenant(db.Model):
__tablename__ = 'tenants'
__table_args__ = (
@@ -132,11 +143,11 @@ class Tenant(db.Model):
Account.id == TenantAccountJoin.account_id,
TenantAccountJoin.tenant_id == self.id
).all()
@property
def custom_config_dict(self) -> dict:
return json.loads(self.custom_config) if self.custom_config else {}
@custom_config_dict.setter
def custom_config_dict(self, value: dict):
self.custom_config = json.dumps(value)

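A tiny standalone sketch of the extracted check; because the enum mixes in str, a raw role string from the database compares (and hashes) like the member's value:

import enum

class TenantAccountRole(str, enum.Enum):
    OWNER = 'owner'
    ADMIN = 'admin'
    NORMAL = 'normal'

    @staticmethod
    def is_privileged_role(role: str) -> bool:
        return role in {TenantAccountRole.ADMIN, TenantAccountRole.OWNER}

print(TenantAccountRole.is_privileged_role('admin'))   # True
print(TenantAccountRole.is_privileged_role('normal'))  # False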

@@ -9,7 +9,7 @@ from sqlalchemy.dialects.postgresql import JSONB, UUID
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Account
from models.model import App, UploadFile
from models.model import App, Tag, TagBinding, UploadFile
class Dataset(db.Model):
@@ -118,6 +118,20 @@ class Dataset(db.Model):
}
return self.retrieval_model if self.retrieval_model else default_retrieval_model
@property
def tags(self):
tags = db.session.query(Tag).join(
TagBinding,
Tag.id == TagBinding.tag_id
).filter(
TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id,
Tag.type == 'knowledge'
).all()
return tags if tags else []
@staticmethod
def gen_collection_name_by_id(dataset_id: str) -> str:
normalized_dataset_id = dataset_id.replace("-", "_")


@@ -148,7 +148,7 @@ class App(db.Model):
return []
agent_mode = app_model_config.agent_mode_dict
tools = agent_mode.get('tools', [])
provider_ids = []
for tool in tools:
@@ -185,6 +185,20 @@ class App(db.Model):
return deleted_tools
@property
def tags(self):
tags = db.session.query(Tag).join(
TagBinding,
Tag.id == TagBinding.tag_id
).filter(
TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id,
Tag.type == 'app'
).all()
return tags if tags else []
class AppModelConfig(db.Model):
__tablename__ = 'app_model_configs'
@@ -292,7 +306,8 @@ class AppModelConfig(db.Model):
@property
def agent_mode_dict(self) -> dict:
return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [],
"prompt": None}
@property
def chat_prompt_config_dict(self) -> dict:
@@ -463,6 +478,7 @@ class InstalledApp(db.Model):
return tenant
class Conversation(db.Model):
__tablename__ = 'conversations'
__table_args__ = (
@@ -1175,11 +1191,11 @@ class MessageAgentThought(db.Model):
return json.loads(self.message_files)
else:
return []
@property
def tools(self) -> list[str]:
return self.tool.split(";") if self.tool else []
@property
def tool_labels(self) -> dict:
try:
@@ -1189,7 +1205,7 @@ class MessageAgentThought(db.Model):
return {}
except Exception as e:
return {}
@property
def tool_meta(self) -> dict:
try:
@@ -1199,7 +1215,7 @@ class MessageAgentThought(db.Model):
return {}
except Exception as e:
return {}
@property
def tool_inputs_dict(self) -> dict:
tools = self.tools
@@ -1222,7 +1238,7 @@ class MessageAgentThought(db.Model):
}
except Exception as e:
return {}
@property
def tool_outputs_dict(self) -> dict:
tools = self.tools
@@ -1249,6 +1265,7 @@ class MessageAgentThought(db.Model):
tool: self.observation for tool in tools
}
class DatasetRetrieverResource(db.Model):
__tablename__ = 'dataset_retriever_resources'
__table_args__ = (
@@ -1274,3 +1291,37 @@ class DatasetRetrieverResource(db.Model):
retriever_from = db.Column(db.Text, nullable=False)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class Tag(db.Model):
__tablename__ = 'tags'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='tag_pkey'),
db.Index('tag_type_idx', 'type'),
db.Index('tag_name_idx', 'name'),
)
TAG_TYPE_LIST = ['knowledge', 'app']
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=True)
type = db.Column(db.String(16), nullable=False)
name = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
class TagBinding(db.Model):
__tablename__ = 'tag_bindings'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='tag_binding_pkey'),
db.Index('tag_bind_target_id_idx', 'target_id'),
db.Index('tag_bind_tag_id_idx', 'tag_id'),
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=True)
tag_id = db.Column(UUID, nullable=True)
target_id = db.Column(UUID, nullable=True)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))


@@ -55,7 +55,7 @@ xinference-client==0.9.4
safetensors~=0.4.3
zhipuai==1.0.7
werkzeug~=3.0.1
pymilvus~=2.3.0
pymilvus==2.3.1
qdrant-client==1.7.3
cohere~=5.2.4
pyyaml~=6.0.1
@@ -67,7 +67,7 @@ httpx[socks]~=0.24.1
matplotlib~=3.8.2
yfinance~=0.2.35
pydub~=0.25.1
gmpy2~=2.2.0a1
gmpy2~=2.1.5
numexpr~=2.9.0
duckduckgo-search==5.2.2
arxiv==2.1.0
@@ -80,3 +80,4 @@ lxml==5.1.0
xlrd~=2.0.1
pydantic~=1.10.0
pgvecto-rs==0.1.4
oss2==2.15.0


@@ -344,9 +344,9 @@ class TenantService:
def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None:
"""Check member permission"""
perms = {
'add': ['owner', 'admin'],
'remove': ['owner'],
'update': ['owner']
'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
'remove': [TenantAccountRole.OWNER],
'update': [TenantAccountRole.OWNER]
}
if action not in ['add', 'remove', 'update']:
raise InvalidActionError("Invalid action.")


@@ -82,19 +82,22 @@ class AgentService:
tool_output = tool_outputs.get(tool_name, {})
tool_meta_data = tool_meta.get(tool_name, {})
tool_config = tool_meta_data.get('tool_config', {})
tool_icon = ToolManager.get_tool_icon(
tenant_id=app_model.tenant_id,
provider_type=tool_config.get('tool_provider_type', ''),
provider_id=tool_config.get('tool_provider', ''),
)
if not tool_icon:
tool_entity = find_agent_tool(tool_name)
if tool_entity:
tool_icon = ToolManager.get_tool_icon(
tenant_id=app_model.tenant_id,
provider_type=tool_entity.provider_type,
provider_id=tool_entity.provider_id,
)
if tool_config.get('tool_provider_type', '') != 'dataset-retrieval':
tool_icon = ToolManager.get_tool_icon(
tenant_id=app_model.tenant_id,
provider_type=tool_config.get('tool_provider_type', ''),
provider_id=tool_config.get('tool_provider', ''),
)
if not tool_icon:
tool_entity = find_agent_tool(tool_name)
if tool_entity:
tool_icon = ToolManager.get_tool_icon(
tenant_id=app_model.tenant_id,
provider_type=tool_entity.provider_type,
provider_id=tool_entity.provider_id,
)
else:
tool_icon = ''
tool_calls.append({
'status': 'success' if not tool_meta_data.get('error') else 'error',


@@ -5,23 +5,28 @@ from typing import cast
import yaml
from flask import current_app
from flask_login import current_user
from flask_sqlalchemy.pagination import Pagination
from constants.model_template import default_app_templates
from core.agent.entities import AgentToolEntity
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode, AppModelConfig
from models.tools import ApiToolProvider
from services.tag_service import TagService
from services.workflow_service import WorkflowService
class AppService:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination:
def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
"""
Get app list with pagination
:param tenant_id: tenant id
@@ -45,6 +50,14 @@ class AppService:
if 'name' in args and args['name']:
name = args['name'][:30]
filters.append(App.name.ilike(f'%{name}%'))
if 'tag_ids' in args and args['tag_ids']:
target_ids = TagService.get_target_ids_by_tag_ids('app',
tenant_id,
args['tag_ids'])
if target_ids:
filters.append(App.id.in_(target_ids))
else:
return None
app_models = db.paginate(
db.select(App).where(*filters).order_by(App.created_at.desc()),
@@ -240,6 +253,64 @@ class AppService:
return yaml.dump(export_data)
def get_app(self, app: App) -> App:
"""
Get App
"""
# get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
model_config: AppModelConfig = app.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app.id}'
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
except Exception as e:
pass
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
class ModifiedApp(App):
"""
Modified App class
"""
def __init__(self, app):
self.__dict__.update(app.__dict__)
@property
def app_model_config(self):
return model_config
app = ModifiedApp(app)
return app
def update_app(self, app: App, args: dict) -> App:
"""
Update app

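The ModifiedApp construction is worth spelling out: get_app copies the row's __dict__ into a subclass whose property shadows app_model_config, so callers receive the masked config without the persisted record being mutated. A standalone sketch of the pattern with toy classes (all names illustrative):

class App:
    def __init__(self, name, config):
        self.name = name
        self.app_model_config = config

def with_masked_config(app: App, masked_config: dict) -> App:
    class ModifiedApp(App):
        def __init__(self, source):
            self.__dict__.update(source.__dict__)

        @property
        def app_model_config(self):
            # the class-level property wins over the copied instance attribute
            return masked_config

    return ModifiedApp(app)

app = App('demo', {'api_key': 'secret'})
print(with_masked_config(app, {'api_key': '***'}).app_model_config)  # {'api_key': '***'}
print(app.app_model_config)  # {'api_key': 'secret'} - original untouched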

@@ -3,7 +3,7 @@ import os
import requests
from extensions.ext_database import db
from models.account import TenantAccountJoin
from models.account import TenantAccountJoin, TenantAccountRole
class BillingService:
@@ -74,5 +74,5 @@ class BillingService:
TenantAccountJoin.account_id == current_user.id
).first()
if join.role not in ['owner', 'admin']:
if not TenantAccountRole.is_privileged_role(join.role):
raise ValueError('Only team owner or team admin can perform this action')


@@ -38,28 +38,39 @@ from services.errors.dataset import DatasetNameDuplicateError
from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.feature_service import FeatureModel, FeatureService
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
class DatasetService:
@staticmethod
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None):
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
if user:
permission_filter = db.or_(Dataset.created_by == user.id,
Dataset.permission == 'all_team_members')
else:
permission_filter = Dataset.permission == 'all_team_members'
datasets = Dataset.query.filter(
query = Dataset.query.filter(
db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
.order_by(Dataset.created_at.desc()) \
.paginate(
.order_by(Dataset.created_at.desc())
if search:
query = query.filter(db.and_(Dataset.name.ilike(f'%{search}%')))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids)
if target_ids:
query = query.filter(db.and_(Dataset.id.in_(target_ids)))
else:
return [], 0
datasets = query.paginate(
page=page,
per_page=per_page,
max_per_page=100,
@@ -165,9 +176,36 @@ class DatasetService:
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_default_model_instance(
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_EMBEDDING
provider=data['embedding_model_provider'],
model_type=ModelType.TEXT_EMBEDDING,
model=data['embedding_model']
)
filtered_data['embedding_model'] = embedding_model.model
filtered_data['embedding_model_provider'] = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider,
embedding_model.model
)
filtered_data['collection_binding_id'] = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
if data['embedding_model_provider'] != dataset.embedding_model_provider or \
data['embedding_model'] != dataset.embedding_model:
action = 'update'
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=data['embedding_model_provider'],
model_type=ModelType.TEXT_EMBEDDING,
model=data['embedding_model']
)
filtered_data['embedding_model'] = embedding_model.model
filtered_data['embedding_model_provider'] = embedding_model.provider
@@ -376,6 +414,15 @@ class DocumentService:
return documents
@staticmethod
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = db.session.query(Document).filter(
Document.dataset_id == dataset_id,
Document.indexing_status.in_(['error', 'paused'])
).all()
return documents
@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
documents = db.session.query(Document).filter(
@@ -440,6 +487,20 @@ class DocumentService:
# trigger async task
recover_document_indexing_task.delay(document.dataset_id, document.id)
@staticmethod
def retry_document(dataset_id: str, documents: list[Document]):
for document in documents:
# retry document indexing
document.indexing_status = 'waiting'
db.session.add(document)
db.session.commit()
# add retry flag
retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
redis_client.setex(retry_indexing_cache_key, 600, 1)
# trigger async task
document_ids = [document.id for document in documents]
retry_document_indexing_task.delay(dataset_id, document_ids)
@staticmethod
def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
@@ -537,6 +598,7 @@ class DocumentService:
db.session.commit()
position = DocumentService.get_documents_position(dataset.id)
document_ids = []
duplicate_document_ids = []
if document_data["data_source"]["type"] == "upload_file":
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
for file_id in upload_file_list:
@@ -553,6 +615,28 @@ class DocumentService:
data_source_info = {
"upload_file_id": file_id,
}
# check duplicate
if document_data.get('duplicate', False):
document = Document.query.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type='upload_file',
enabled=True,
name=file_name
).first()
if document:
document.dataset_process_rule_id = dataset_process_rule.id
document.updated_at = datetime.datetime.utcnow()
document.created_from = created_from
document.doc_form = document_data['doc_form']
document.doc_language = document_data['doc_language']
document.data_source_info = json.dumps(data_source_info)
document.batch = batch
document.indexing_status = 'waiting'
db.session.add(document)
documents.append(document)
duplicate_document_ids.append(document.id)
continue
document = DocumentService.build_document(dataset, dataset_process_rule.id,
document_data["data_source"]["type"],
document_data["doc_form"],
@@ -618,7 +702,10 @@ class DocumentService:
db.session.commit()
# trigger async task
document_indexing_task.delay(dataset.id, document_ids)
if document_ids:
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
return documents, batch
@@ -626,7 +713,8 @@ class DocumentService:
def check_documents_upload_quota(count: int, features: FeatureModel):
can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
if count > can_upload_size:
raise ValueError(f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
raise ValueError(
f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
@staticmethod
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
@@ -752,7 +840,6 @@ class DocumentService:
db.session.commit()
# trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id)
return document
@staticmethod

api/services/tag_service.py Normal file

@@ -0,0 +1,161 @@
import uuid
from flask_login import current_user
from sqlalchemy import func
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
query = db.session.query(
Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
).outerjoin(
TagBinding, Tag.id == TagBinding.tag_id
).filter(
Tag.type == tag_type,
Tag.tenant_id == current_tenant_id
)
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
query = query.group_by(
Tag.id
)
results = query.order_by(Tag.created_at.desc()).all()
return results
@staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
tags = db.session.query(Tag).filter(
Tag.id.in_(tag_ids),
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type
).all()
if not tags:
return []
tag_ids = [tag.id for tag in tags]
tag_bindings = db.session.query(
TagBinding.target_id
).filter(
TagBinding.tag_id.in_(tag_ids),
TagBinding.tenant_id == current_tenant_id
).all()
if not tag_bindings:
return []
results = [tag_binding.target_id for tag_binding in tag_bindings]
return results
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
tags = db.session.query(Tag).join(
TagBinding,
Tag.id == TagBinding.tag_id
).filter(
TagBinding.target_id == target_id,
TagBinding.tenant_id == current_tenant_id,
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type
).all()
return tags if tags else []
@staticmethod
def save_tags(args: dict) -> Tag:
tag = Tag(
id=str(uuid.uuid4()),
name=args['name'],
type=args['type'],
created_by=current_user.id,
tenant_id=current_user.current_tenant_id
)
db.session.add(tag)
db.session.commit()
return tag
@staticmethod
def update_tags(args: dict, tag_id: str) -> Tag:
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
tag.name = args['name']
db.session.commit()
return tag
@staticmethod
def get_tag_binding_count(tag_id: str) -> int:
count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
return count
@staticmethod
def delete_tag(tag_id: str):
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)
db.session.commit()
@staticmethod
def save_tag_binding(args):
# check if target exists
TagService.check_target_exists(args['type'], args['target_id'])
# save tag binding
for tag_id in args['tag_ids']:
tag_binding = db.session.query(TagBinding).filter(
TagBinding.tag_id == tag_id,
TagBinding.target_id == args['target_id']
).first()
if tag_binding:
continue
new_tag_binding = TagBinding(
tag_id=tag_id,
target_id=args['target_id'],
tenant_id=current_user.current_tenant_id,
created_by=current_user.id
)
db.session.add(new_tag_binding)
db.session.commit()
@staticmethod
def delete_tag_binding(args):
# check if target exists
TagService.check_target_exists(args['type'], args['target_id'])
# delete tag binding
tag_bindings = db.session.query(TagBinding).filter(
TagBinding.target_id == args['target_id'],
TagBinding.tag_id == (args['tag_id'])
).first()
if tag_bindings:
db.session.delete(tag_bindings)
db.session.commit()
@staticmethod
def check_target_exists(type: str, target_id: str):
if type == 'knowledge':
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == current_user.current_tenant_id,
Dataset.id == target_id
).first()
if not dataset:
raise NotFound("Dataset not found")
elif type == 'app':
app = db.session.query(App).filter(
App.tenant_id == current_user.current_tenant_id,
App.id == target_id
).first()
if not app:
raise NotFound("App not found")
else:
raise NotFound("Invalid binding type")

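The core of get_tags is an outer join to tag_bindings with a grouped count, so tags with zero bindings still appear with binding_count 0. A self-contained approximation against in-memory SQLite, with plain SQLAlchemy standing in for the Flask-SQLAlchemy session:

from sqlalchemy import Column, String, create_engine, func
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Tag(Base):
    __tablename__ = 'tags'
    id = Column(String, primary_key=True)
    name = Column(String)
    type = Column(String)

class TagBinding(Base):
    __tablename__ = 'tag_bindings'
    id = Column(String, primary_key=True)
    tag_id = Column(String)
    target_id = Column(String)

engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        Tag(id='t1', name='faq', type='knowledge'),
        Tag(id='t2', name='draft', type='knowledge'),  # unbound -> counts as 0
        TagBinding(id='b1', tag_id='t1', target_id='ds-1'),
        TagBinding(id='b2', tag_id='t1', target_id='ds-2'),
    ])
    session.commit()

    rows = (
        session.query(Tag.id, Tag.name, func.count(TagBinding.id).label('binding_count'))
        .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
        .filter(Tag.type == 'knowledge')
        .group_by(Tag.id)
        .order_by(Tag.id)
        .all()
    )
    print(rows)  # [('t1', 'faq', 2), ('t2', 'draft', 0)]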

@@ -16,6 +16,7 @@ from models.dataset import (
)
# Add import statement for ValueError
@shared_task(queue='dataset')
def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
index_struct: str, collection_binding_id: str, doc_form: str):
@@ -48,6 +49,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
logging.info(click.style('No documents found for dataset: {}'.format(dataset_id), fg='green'))
else:
logging.info(click.style('Cleaning documents for dataset: {}'.format(dataset_id), fg='green'))
# Specify the index type before initializing the index processor
if doc_form is None:
raise ValueError("Index type must be specified.")
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None)


@@ -64,6 +64,39 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
elif action == 'update':
# clean index
index_processor.clean(dataset, None, with_keywords=False)
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
# add new index
if dataset_documents:
documents = []
for dataset_document in dataset_documents:
# rebuild documents from the completed segments
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True
).order_by(DocumentSegment.position.asc()).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
end_at = time.perf_counter()
logging.info(


@@ -0,0 +1,94 @@
import datetime
import logging
import time
import click
from celery import shared_task
from flask import current_app
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@shared_task(queue='dataset')
def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
"""
Async process document
:param dataset_id:
:param document_ids:
Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
"your subscription.")
except Exception as e:
for document_id in document_ids:
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
if document:
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
return
for document_id in document_ids:
logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
if document:
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = 'parsing'
document.processing_started_at = datetime.datetime.utcnow()
documents.append(document)
db.session.add(document)
db.session.commit()
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
except DocumentIsPausedException as ex:
logging.info(click.style(str(ex), fg='yellow'))
except Exception:
pass


@@ -0,0 +1,91 @@
import datetime
import logging
import time
import click
from celery import shared_task
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@shared_task(queue='dataset')
def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
"""
Async process document
:param dataset_id:
:param document_ids:
Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
for document_id in document_ids:
retry_indexing_cache_key = 'document_{}_is_retried'.format(document_id)
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
"your subscription.")
except Exception as e:
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
if document:
document.indexing_status = 'error'
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
return
logging.info(click.style('Start retry document: {}'.format(document_id), fg='green'))
document = db.session.query(Document).filter(
Document.id == document_id,
Document.dataset_id == dataset_id
).first()
try:
if document:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = 'parsing'
document.processing_started_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = 'error'
document.error = str(ex)
document.stopped_at = datetime.datetime.utcnow()
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg='yellow'))
redis_client.delete(retry_indexing_cache_key)
pass
end_at = time.perf_counter()
logging.info(click.style('Retry dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
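
For reference, a minimal sketch (not part of this diff) of how a caller might enqueue this retry task; the module path and the TTL below are assumptions, while the cache-key format and the .delay() signature match the task above:

from extensions.ext_redis import redis_client
from tasks.retry_document_indexing_task import retry_document_indexing_task  # module path assumed

dataset_id = 'your-dataset-id'
document_ids = ['your-document-id']
for document_id in document_ids:
    # mark each document as being retried; the task deletes this key when it finishes or fails
    redis_client.setex('document_{}_is_retried'.format(document_id), 600, 1)
retry_document_indexing_task.delay(dataset_id, document_ids)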


@@ -0,0 +1,38 @@
import uuid

from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
    get_sample_document,
    get_sample_embedding,
    get_sample_query_vector,
    setup_mock_redis,
)


def test_milvus_vector(setup_mock_redis) -> None:
    dataset_id = str(uuid.uuid4())
    vector = MilvusVector(
        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
        config=MilvusConfig(
            host='localhost',
            port=19530,
            user='root',
            password='Milvus',
        )
    )

    # create vector
    vector.create(
        texts=[get_sample_document(dataset_id)],
        embeddings=[get_sample_embedding()],
    )

    # search by vector
    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
    assert len(hits_by_vector) >= 1

    # Milvus does not support full-text search yet in versions < 2.3.x

    # delete vector
    vector.delete()


@@ -0,0 +1,40 @@
import uuid

from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
    get_sample_document,
    get_sample_embedding,
    get_sample_query_vector,
    get_sample_text,
    setup_mock_redis,
)


def test_qdrant_vector(setup_mock_redis) -> None:
    dataset_id = str(uuid.uuid4())
    vector = QdrantVector(
        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
        group_id=dataset_id,
        config=QdrantConfig(
            endpoint='http://localhost:6333',
            api_key='difyai123456',
        )
    )

    # create vector
    vector.create(
        texts=[get_sample_document(dataset_id)],
        embeddings=[get_sample_embedding()],
    )

    # search by vector
    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
    assert len(hits_by_vector) >= 1

    # search by full text
    hits_by_full_text = vector.search_by_full_text(query=get_sample_text())
    assert len(hits_by_full_text) >= 1

    # delete vector
    vector.delete()


@@ -0,0 +1,46 @@
from unittest.mock import MagicMock

import pytest

from core.rag.models.document import Document
from extensions import ext_redis


def get_sample_text() -> str:
    return 'test_text'


def get_sample_embedding() -> list[float]:
    return [1.1, 2.2, 3.3]


def get_sample_query_vector() -> list[float]:
    return get_sample_embedding()


def get_sample_document(sample_dataset_id: str) -> Document:
    doc = Document(
        page_content=get_sample_text(),
        metadata={
            "doc_id": sample_dataset_id,
            "doc_hash": sample_dataset_id,
            "document_id": sample_dataset_id,
            "dataset_id": sample_dataset_id,
        }
    )
    return doc


@pytest.fixture
def setup_mock_redis() -> None:
    # get
    ext_redis.redis_client.get = MagicMock(return_value=None)

    # set
    ext_redis.redis_client.set = MagicMock(return_value=None)

    # lock
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = mock_redis_lock
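
The helpers above are shared by the Milvus, Qdrant and Weaviate tests in this diff, and a test for a new store follows the same shape. A minimal sketch, where FooVector is a hypothetical store exposing the same create/search_by_vector/delete interface:

import uuid

from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
    get_sample_document,
    get_sample_embedding,
    get_sample_query_vector,
    setup_mock_redis,
)


def test_foo_vector(setup_mock_redis) -> None:
    dataset_id = str(uuid.uuid4())
    vector = FooVector(collection_name=Dataset.gen_collection_name_by_id(dataset_id))  # hypothetical store

    vector.create(texts=[get_sample_document(dataset_id)], embeddings=[get_sample_embedding()])
    assert len(vector.search_by_vector(query_vector=get_sample_query_vector())) >= 1
    vector.delete()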


@@ -0,0 +1,41 @@
import uuid

from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
    get_sample_document,
    get_sample_embedding,
    get_sample_query_vector,
    get_sample_text,
    setup_mock_redis,
)


def test_weaviate_vector(setup_mock_redis) -> None:
    attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
    dataset_id = str(uuid.uuid4())
    vector = WeaviateVector(
        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
        config=WeaviateConfig(
            endpoint='http://localhost:8080',
            api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih',
        ),
        attributes=attributes
    )

    # create vector
    vector.create(
        texts=[get_sample_document(dataset_id)],
        embeddings=[get_sample_embedding()],
    )

    # search by vector
    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
    assert len(hits_by_vector) >= 1

    # search by full text
    hits_by_full_text = vector.search_by_full_text(query=get_sample_text())
    assert len(hits_by_full_text) >= 1

    # delete vector
    vector.delete()


@@ -65,7 +65,8 @@ def test_execute_llm(setup_openai_mock):
     pool = VariablePool(system_variables={
         SystemVariable.QUERY: 'what\'s the weather today?',
         SystemVariable.FILES: [],
-        SystemVariable.CONVERSATION: 'abababa'
+        SystemVariable.CONVERSATION_ID: 'abababa',
+        SystemVariable.USER_ID: 'aaa'
     }, user_inputs={})
     pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')


@@ -238,8 +238,8 @@ def test__get_completion_model_prompt_messages():
     prompt_rules = prompt_template['prompt_rules']
     full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
         max_token_limit=2000,
-        ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-        human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+        human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+        ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
     )}
     real_prompt = prompt_template['prompt_template'].format(full_inputs)


@@ -18,7 +18,7 @@ def test_default_value():
     with pytest.raises(ValidationError) as e:
         MilvusConfig(**config)
     assert e.value.errors()[1]['msg'] == f'config MILVUS_{key.upper()} is required'

     config = MilvusConfig(**valid_config)
     assert config.secure is False
     assert config.database == 'default'
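
For reference, a fully explicit MilvusConfig that mirrors the connection values used in the Milvus integration test above; secure and database keep the defaults asserted here (a sketch, not part of the diff):

from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig

# host/port/user/password as in the integration test; secure and database
# fall back to their defaults (False and 'default') when omitted
config = MilvusConfig(host='localhost', port=19530, user='root', password='Milvus')
assert config.secure is False
assert config.database == 'default'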


@@ -28,6 +28,7 @@ def test_execute_answer():
     # construct variable pool
     pool = VariablePool(system_variables={
         SystemVariable.FILES: [],
+        SystemVariable.USER_ID: 'aaa'
     }, user_inputs={})
     pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny')
     pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.')


@@ -118,6 +118,7 @@ def test_execute_if_else_result_true():
     # construct variable pool
     pool = VariablePool(system_variables={
         SystemVariable.FILES: [],
+        SystemVariable.USER_ID: 'aaa'
     }, user_inputs={})
     pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
     pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
@@ -179,6 +180,7 @@ def test_execute_if_else_result_false():
     # construct variable pool
     pool = VariablePool(system_variables={
         SystemVariable.FILES: [],
+        SystemVariable.USER_ID: 'aaa'
     }, user_inputs={})
     pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['1ab', 'def'])
     pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ab', 'def'])


@@ -0,0 +1,29 @@
import rsa as pyrsa
from Crypto.PublicKey import RSA

from libs import gmpy2_pkcs10aep_cipher


def test_gmpy2_pkcs10aep_cipher() -> None:
    rsa_key_pair = pyrsa.newkeys(2048)
    public_key = rsa_key_pair[0].save_pkcs1()
    private_key = rsa_key_pair[1].save_pkcs1()

    public_rsa_key = RSA.import_key(public_key)
    public_cipher_rsa = gmpy2_pkcs10aep_cipher.new(public_rsa_key)

    private_rsa_key = RSA.import_key(private_key)
    private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key)

    raw_text = 'raw_text'
    raw_text_bytes = raw_text.encode()

    # encrypt with the public key, decrypt with the private key
    encrypted_by_pub_key = public_cipher_rsa.encrypt(message=raw_text_bytes)
    decrypted_from_pub_key = private_cipher_rsa.decrypt(encrypted_by_pub_key)
    assert decrypted_from_pub_key == raw_text_bytes

    # encrypt and decrypt with the private key
    encrypted_by_private_key = private_cipher_rsa.encrypt(message=raw_text_bytes)
    decrypted_by_private_key = private_cipher_rsa.decrypt(encrypted_by_private_key)
    assert decrypted_by_private_key == raw_text_bytes


@@ -0,0 +1,12 @@
from models.account import TenantAccountRole


def test_account_is_privileged_role() -> None:
    assert TenantAccountRole.ADMIN == 'admin'
    assert TenantAccountRole.OWNER == 'owner'
    assert TenantAccountRole.NORMAL == 'normal'

    assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN)
    assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER)

    assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL)
    assert not TenantAccountRole.is_privileged_role('')


@@ -89,7 +89,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
         )
     ]

-    nodes = workflow_converter._convert_to_http_request_node(
+    nodes, _ = workflow_converter._convert_to_http_request_node(
         app_model=app_model,
         variables=default_variables,
         external_data_variables=external_data_variables
@@ -159,7 +159,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
         )
     ]

-    nodes = workflow_converter._convert_to_http_request_node(
+    nodes, _ = workflow_converter._convert_to_http_request_node(
         app_model=app_model,
         variables=default_variables,
         external_data_variables=external_data_variables


@@ -9,3 +9,6 @@ dev/pytest/pytest_tools.sh
 # Workflow
 dev/pytest/pytest_workflow.sh
+
+# Unit tests
+dev/pytest/pytest_unit_tests.sh


@@ -0,0 +1,5 @@
#!/bin/bash
set -x
# run all unit tests
pytest api/tests/unit_tests

dev/pytest/pytest_vdb.sh Executable file

@@ -0,0 +1,4 @@
#!/bin/bash
set -x
pytest api/tests/integration_tests/vdb/


@@ -36,7 +36,7 @@ services:
       timeout: 20s
       retries: 3

-  standalone:
+  milvus-standalone:
     container_name: milvus-standalone
     image: milvusdb/milvus:v2.3.1
     command: ["milvus", "run", "standalone"]


@@ -0,0 +1,12 @@
version: '3'
services:
  # Qdrant vector store.
  qdrant:
    image: langgenius/qdrant:v1.7.3
    restart: always
    volumes:
      - ./volumes/qdrant:/qdrant/storage
    environment:
      QDRANT_API_KEY: 'difyai123456'
    ports:
      - "6333:6333"


@@ -2,7 +2,7 @@ version: '3'
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.6.4
+    image: langgenius/dify-api:0.6.5
     restart: always
     environment:
       # Startup mode, 'api' starts the API server.
@@ -150,7 +150,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.6.4
+    image: langgenius/dify-api:0.6.5
     restart: always
     environment:
       # Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -238,7 +238,7 @@ services:
   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.6.4
+    image: langgenius/dify-web:0.6.5
     restart: always
     environment:
       EDITION: SELF_HOSTED

Some files were not shown because too many files have changed in this diff.