Mirror of https://github.com/langgenius/dify.git (synced 2026-01-10 00:04:14 +00:00)
Compare commits
54 Commits
| SHA1 |
|---|
| 34bfb715e1 |
| 019d7069f8 |
| c54fcfb45d |
| cde87cb225 |
| 12435774ca |
| 80b9507e7a |
| 0ac0f0ffd0 |
| 3d14aba4b4 |
| 64f694865c |
| d36b728088 |
| 1a7b4c42ab |
| 2a64ce740e |
| 78988ed60e |
| 2832adda88 |
| a4e4fb4094 |
| 777ec64635 |
| 9cec8c1750 |
| 8ca5aa1190 |
| 4d8f1b9ca4 |
| 3da179f77b |
| a34e8cb0bd |
| b249767c5c |
| 89a7434565 |
| 3b537cbdeb |
| 731464f5b8 |
| 1ad70f8721 |
| 2ea8c73cd8 |
| f257f2c396 |
| 3cd8e6f5c6 |
| 0715db7681 |
| a39de8a686 |
| ccaf335466 |
| 40e36e9b52 |
| 9eebe9d54e |
| a23a191615 |
| 7d9c5586f9 |
| f07c89bba4 |
| 59cba930e5 |
| 39ae56e136 |
| f92130338b |
| 2867d29021 |
| f76ac8bdee |
| 83caffe000 |
| 96160837d2 |
| 3480f1c59e |
| 65ac4f69af |
| 2c50fab3dd |
| 9525ccac4f |
| ff76c4bd5d |
| 5dacf77627 |
| 2a213c6af7 |
| b2535e7db6 |
| 28236147ee |
| 4969783383 |
.github/workflows/api-tests.yml (vendored, 32 changed lines)

```diff
@@ -10,7 +10,9 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10", "3.11", "3.12"]
+        python-version:
+          - "3.10"
+          - "3.11"

     env:
       OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
@@ -35,10 +37,26 @@ jobs:
     - name: Checkout code
       uses: actions/checkout@v4

-    - name: Install APT packages
-      uses: awalsh128/cache-apt-pkgs-action@v1
+    - name: Set up Weaviate
+      uses: hoverkraft-tech/compose-action@v2.0.0
       with:
-        packages: ffmpeg
+        compose-file: docker/docker-compose.middleware.yaml
+        services: weaviate
+
+    - name: Set up Qdrant
+      uses: hoverkraft-tech/compose-action@v2.0.0
+      with:
+        compose-file: docker/docker-compose.qdrant.yaml
+        services: qdrant
+
+    - name: Set up Milvus
+      uses: hoverkraft-tech/compose-action@v2.0.0
+      with:
+        compose-file: docker/docker-compose.milvus.yaml
+        services: |
+          etcd
+          minio
+          milvus-standalone

     - name: Set up Python ${{ matrix.python-version }}
       uses: actions/setup-python@v5
@@ -52,6 +70,9 @@ jobs:
     - name: Install dependencies
       run: pip install -r ./api/requirements.txt -r ./api/requirements-dev.txt

+    - name: Run Unit tests
+      run: dev/pytest/pytest_unit_tests.sh
+
     - name: Run ModelRuntime
       run: dev/pytest/pytest_model_runtime.sh

@@ -60,3 +81,6 @@ jobs:

     - name: Run Workflow
       run: dev/pytest/pytest_workflow.sh
+
+    - name: Run Vector Stores
+      run: dev/pytest/pytest_vdb.sh
```
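The compose steps above replace a single ffmpeg install with three vector-store services that `dev/pytest/pytest_vdb.sh` can target. As a rough illustration of what such a test can check once the containers are up, here is a minimal sketch; the port and readiness endpoint are assumptions based on Weaviate's defaults, not something taken from this repository:

```python
import requests


def test_weaviate_is_ready():
    # Weaviate exposes a readiness probe; 8080 is its default HTTP port.
    resp = requests.get("http://localhost:8080/v1/.well-known/ready", timeout=5)
    assert resp.status_code == 200
```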
README.md (14 changed lines)

```diff
@@ -29,12 +29,12 @@
 </p>

 <p align="center">
-  <a href="./README.md"><img alt="Commits last month" src="https://img.shields.io/badge/English-d9d9d9"></a>
-  <a href="./README_CN.md"><img alt="Commits last month" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
-  <a href="./README_JA.md"><img alt="Commits last month" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
-  <a href="./README_ES.md"><img alt="Commits last month" src="https://img.shields.io/badge/Español-d9d9d9"></a>
-  <a href="./README_FR.md"><img alt="Commits last month" src="https://img.shields.io/badge/Français-d9d9d9"></a>
-  <a href="./README_KL.md"><img alt="Commits last month" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
+  <a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
+  <a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
+  <a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
+  <a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
+  <a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
+  <a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
 </p>

 #
@@ -54,7 +54,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com

 **2. Comprehensive model support**:
-Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
+Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).

 
```
```diff
@@ -54,7 +54,7 @@ Dify es una plataforma de desarrollo de aplicaciones de LLM de código abierto.

 **2. Soporte de modelos completo**:
-Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama2 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers).
+Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama3 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers).

 
```
```diff
@@ -54,7 +54,7 @@ Dify est une plateforme de développement d'applications LLM open source. Son in

 **2. Prise en charge complète des modèles**:
-Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama2, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
+Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).

 
```
```diff
@@ -55,9 +55,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ

 **2. 網羅的なモデルサポート**:
-数百のプロプライエタリ/オープンソースのLLMと、数十の推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama2、およびOpenAI API互換のモデルをカバーします。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs
-
-.dify.ai/getting-started/readme/model-providers)をご覧ください。
+数百のプロプライエタリ/オープンソースのLLMと、数十の推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama3、およびOpenAI API互換のモデルをカバーします。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs.dify.ai/getting-started/readme/model-providers)をご覧ください。

 
@@ -155,9 +153,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
 さらなる参照や詳細な手順については、[ドキュメント](https://docs.dify.ai)をご覧ください。

 - **エンタープライズ/組織向けのDify</br>**
-  追加のエンタープライズ向け機能を提供しています。[こちらからミーティ
-
-ングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)してエンタープライズのニーズについて相談してください。 </br>
+  追加のエンタープライズ向け機能を提供しています。[こちらからミーティングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)してエンタープライズのニーズについて相談してください。 </br>
 > AWSを使用しているスタートアップや中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで独自のAWS VPCにデプロイできます。カスタムロゴとブランディングでアプリを作成するオプションを備えた手頃な価格のAMIオファリングです。
```
```diff
@@ -54,7 +54,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com

 **2. Comprehensive model support**:
-Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
+Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).

 
```
```diff
@@ -52,6 +52,11 @@ AZURE_BLOB_ACCOUNT_NAME=your-account-name
 AZURE_BLOB_ACCOUNT_KEY=your-account-key
 AZURE_BLOB_CONTAINER_NAME=yout-container-name
 AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
+# Aliyun oss Storage configuration
+ALIYUN_OSS_BUCKET_NAME=your-bucket-name
+ALIYUN_OSS_ACCESS_KEY=your-access-key
+ALIYUN_OSS_SECRET_KEY=your-secret-key
+ALIYUN_OSS_ENDPOINT=your-endpoint

 # CORS configuration
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@@ -160,3 +165,6 @@ CODE_MAX_NUMBER_ARRAY_LENGTH=1000
 # API Tool configuration
 API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
 API_TOOL_DEFAULT_READ_TIMEOUT=60
+
+# Log file path
+LOG_FILE=
```
```diff
@@ -104,7 +104,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.6.4"
+        self.CURRENT_VERSION = "0.6.5"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -208,6 +208,10 @@ class Config:
         self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
         self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
         self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')
+        self.ALIYUN_OSS_BUCKET_NAME=get_env('ALIYUN_OSS_BUCKET_NAME')
+        self.ALIYUN_OSS_ACCESS_KEY=get_env('ALIYUN_OSS_ACCESS_KEY')
+        self.ALIYUN_OSS_SECRET_KEY=get_env('ALIYUN_OSS_SECRET_KEY')
+        self.ALIYUN_OSS_ENDPOINT=get_env('ALIYUN_OSS_ENDPOINT')

         # ------------------------
         # Vector Store Configurations.
```
```diff
@@ -53,5 +53,8 @@ from .explore import (
     workflow,
 )

+# Import tag controllers
+from .tag import tags
+
 # Import workspace controllers
 from .workspace import account, members, model_providers, models, tool_providers, workspace
```
```diff
@@ -1,17 +1,16 @@
-import json
 import uuid

 from flask_login import current_user
-from flask_restful import Resource, inputs, marshal_with, reqparse
-from werkzeug.exceptions import BadRequest, Forbidden
+from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
+from werkzeug.exceptions import BadRequest, Forbidden, abort

 from controllers.console import api
 from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from core.agent.entities import AgentToolEntity
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.configuration import ToolParameterConfigurationManager
 from extensions.ext_database import db
 from fields.app_fields import (
     app_detail_fields,
     app_detail_fields_with_site,
@@ -20,6 +19,7 @@ from fields.app_fields import (
 from libs.login import login_required
 from models.model import App, AppMode, AppModelConfig
 from services.app_service import AppService
+from services.tag_service import TagService

 ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']

@@ -29,21 +29,29 @@ class AppListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(app_pagination_fields)
     def get(self):
         """Get app list"""
+        def uuid_list(value):
+            try:
+                return [str(uuid.UUID(v)) for v in value.split(',')]
+            except ValueError:
+                abort(400, message="Invalid UUID format in tag_ids.")
+
         parser = reqparse.RequestParser()
         parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
         parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
         parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
         parser.add_argument('name', type=str, location='args', required=False)
+        parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)

         args = parser.parse_args()

         # get app list
         app_service = AppService()
         app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
         if not app_pagination:
             return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}

-        return app_pagination
+        return marshal(app_pagination, app_pagination_fields)

     @setup_required
     @login_required
@@ -108,43 +116,9 @@ class AppApi(Resource):
     @marshal_with(app_detail_fields_with_site)
     def get(self, app_model):
         """Get app detail"""
-        # get original app model config
-        if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
-            model_config: AppModelConfig = app_model.app_model_config
-            agent_mode = model_config.agent_mode_dict
-            # decrypt agent tool parameters if it's secret-input
-            for tool in agent_mode.get('tools') or []:
-                if not isinstance(tool, dict) or len(tool.keys()) <= 3:
-                    continue
-                agent_tool_entity = AgentToolEntity(**tool)
-                # get tool
-                try:
-                    tool_runtime = ToolManager.get_agent_tool_runtime(
-                        tenant_id=current_user.current_tenant_id,
-                        agent_tool=agent_tool_entity,
-                    )
-                    manager = ToolParameterConfigurationManager(
-                        tenant_id=current_user.current_tenant_id,
-                        tool_runtime=tool_runtime,
-                        provider_name=agent_tool_entity.provider_id,
-                        provider_type=agent_tool_entity.provider_type,
-                    )
+        app_service = AppService()

-                    # get decrypted parameters
-                    if agent_tool_entity.tool_parameters:
-                        parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
-                        masked_parameter = manager.mask_tool_parameters(parameters or {})
-                    else:
-                        masked_parameter = {}
-
-                    # override tool parameters
-                    tool['tool_parameters'] = masked_parameter
-                except Exception as e:
-                    pass
-
-            # override agent mode
-            model_config.agent_mode = json.dumps(agent_mode)
-            db.session.commit()
+        app_model = app_service.get_app(app_model)

         return app_model
```
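In the `AppListApi` hunk above, `uuid_list` works because flask-restful's reqparse calls the `type` callable on the raw query-string value and stores whatever it returns. A standalone sketch of the same pattern; the resource and route here are hypothetical, only the validator mirrors the diff:

```python
import uuid

from flask import Flask
from flask_restful import Api, Resource, abort, reqparse


def uuid_list(value: str) -> list[str]:
    # reqparse passes the raw "a,b,c" string; validate each piece as a UUID.
    try:
        return [str(uuid.UUID(v)) for v in value.split(',')]
    except ValueError:
        abort(400, message="Invalid UUID format in tag_ids.")


class TaggedThings(Resource):
    def get(self):
        parser = reqparse.RequestParser()
        parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
        args = parser.parse_args()
        return {'tag_ids': args['tag_ids'] or []}


app = Flask(__name__)
Api(app).add_resource(TaggedThings, '/things')
```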
```diff
@@ -57,6 +57,7 @@ class ModelConfigResource(Resource):
             try:
                 tool_runtime = ToolManager.get_agent_tool_runtime(
                     tenant_id=current_user.current_tenant_id,
+                    app_id=app_model.id,
                     agent_tool=agent_tool_entity,
                 )
                 manager = ToolParameterConfigurationManager(
@@ -64,6 +65,7 @@ class ModelConfigResource(Resource):
                     tool_runtime=tool_runtime,
                     provider_name=agent_tool_entity.provider_id,
                     provider_type=agent_tool_entity.provider_type,
+                    identity_id=f'AGENT.{app_model.id}'
                 )
             except Exception as e:
                 continue
@@ -94,6 +96,7 @@ class ModelConfigResource(Resource):
             try:
                 tool_runtime = ToolManager.get_agent_tool_runtime(
                     tenant_id=current_user.current_tenant_id,
+                    app_id=app_model.id,
                     agent_tool=agent_tool_entity,
                 )
             except Exception as e:
@@ -104,6 +107,7 @@ class ModelConfigResource(Resource):
                 tool_runtime=tool_runtime,
                 provider_name=agent_tool_entity.provider_id,
                 provider_type=agent_tool_entity.provider_type,
+                identity_id=f'AGENT.{app_model.id}'
             )
             manager.delete_tool_parameters_cache()

@@ -111,9 +115,11 @@ class ModelConfigResource(Resource):
             if agent_tool_entity.tool_parameters:
                 if key not in masked_parameter_map:
                     continue

-                if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
-                    agent_tool_entity.tool_parameters = parameter_map[key]
+                for masked_key, masked_value in masked_parameter_map[key].items():
+                    if masked_key in agent_tool_entity.tool_parameters and \
+                        agent_tool_entity.tool_parameters[masked_key] == masked_value:
+                        agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)

             # encrypt parameters
             if agent_tool_entity.tool_parameters:
```
```diff
@@ -48,11 +48,14 @@ class DatasetListApi(Resource):
         limit = request.args.get('limit', default=20, type=int)
         ids = request.args.getlist('ids')
         provider = request.args.get('provider', default="vendor")
+        search = request.args.get('keyword', default=None, type=str)
+        tag_ids = request.args.getlist('tag_ids')

         if ids:
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
             datasets, total = DatasetService.get_datasets(page, limit, provider,
-                                                          current_user.current_tenant_id, current_user)
+                                                          current_user.current_tenant_id, current_user, search, tag_ids)

         # check embedding setting
         provider_manager = ProviderManager()
@@ -184,6 +187,10 @@ class DatasetApi(Resource):
                             help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
+        parser.add_argument('embedding_model', type=str,
+                            location='json', help='Invalid embedding model.')
+        parser.add_argument('embedding_model_provider', type=str,
+                            location='json', help='Invalid embedding model provider.')
         parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
         args = parser.parse_args()

@@ -506,10 +513,27 @@ class DatasetRetrievalSettingMockApi(Resource):
         else:
             raise ValueError("Unsupported vector db type.")

+class DatasetErrorDocs(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
+
+        return {
+            'data': [marshal(item, document_status_fields) for item in results],
+            'total': len(results)
+        }, 200
+

 api.add_resource(DatasetListApi, '/datasets')
 api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
 api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
+api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
 api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
 api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
 api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
```
```diff
@@ -1,3 +1,4 @@
+import logging
 from datetime import datetime, timezone

 from flask import request
@@ -233,7 +234,7 @@ class DatasetDocumentListApi(Resource):
                             location='json')
         parser.add_argument('data_source', type=dict, required=False, location='json')
         parser.add_argument('process_rule', type=dict, required=False, location='json')
-        parser.add_argument('duplicate', type=bool, nullable=False, location='json')
+        parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json')
         parser.add_argument('original_document_id', type=str, required=False, location='json')
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
@@ -883,6 +884,49 @@ class DocumentRecoverApi(DocumentResource):
         return {'result': 'success'}, 204


+class DocumentRetryApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, dataset_id):
+        """retry document."""
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('document_ids', type=list, required=True, nullable=False,
+                            location='json')
+        args = parser.parse_args()
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        retry_documents = []
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        for document_id in args['document_ids']:
+            try:
+                document_id = str(document_id)
+
+                document = DocumentService.get_document(dataset.id, document_id)
+
+                # 404 if document not found
+                if document is None:
+                    raise NotFound("Document Not Exists.")
+
+                # 403 if document is archived
+                if DocumentService.check_archived(document):
+                    raise ArchivedDocumentImmutableError()
+
+                # 400 if document is completed
+                if document.indexing_status == 'completed':
+                    raise DocumentAlreadyFinishedError()
+                retry_documents.append(document)
+            except Exception as e:
+                logging.error(f"Document {document_id} retry failed: {str(e)}")
+                continue
+        # retry document
+        DocumentService.retry_document(dataset_id, retry_documents)
+
+        return {'result': 'success'}, 204
+
+
 api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
 api.add_resource(DatasetDocumentListApi,
                  '/datasets/<uuid:dataset_id>/documents')
@@ -908,3 +952,4 @@ api.add_resource(DocumentStatusApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
 api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
 api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
+api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
```
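The new retry endpoint accepts a JSON list of document IDs and answers 204 on success. A hypothetical client call as a sketch — the host, the `/console/api` mount point, the IDs, and authentication handling are all placeholders, not taken from this diff:

```python
import requests

# Console-API authentication (session cookie / token) is omitted here.
resp = requests.post(
    "http://localhost:5001/console/api/datasets/<dataset_id>/retry",
    json={"document_ids": ["<document_uuid>"]},
)
assert resp.status_code == 204
```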
api/controllers/console/tag/tags.py (new file, 159 lines)

```diff
@@ -0,0 +1,159 @@
+from flask import request
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from fields.tag_fields import tag_fields
+from libs.login import login_required
+from models.model import Tag
+from services.tag_service import TagService
+
+
+def _validate_name(name):
+    if not name or len(name) < 1 or len(name) > 40:
+        raise ValueError('Name must be between 1 to 50 characters.')
+    return name
+
+
+class TagListApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(tag_fields)
+    def get(self):
+        tag_type = request.args.get('type', type=str)
+        keyword = request.args.get('keyword', default=None, type=str)
+        tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
+
+        return tags, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', nullable=False, required=True,
+                            help='Name must be between 1 to 50 characters.',
+                            type=_validate_name)
+        parser.add_argument('type', type=str, location='json',
+                            choices=Tag.TAG_TYPE_LIST,
+                            nullable=True,
+                            help='Invalid tag type.')
+        args = parser.parse_args()
+        tag = TagService.save_tags(args)
+
+        response = {
+            'id': tag.id,
+            'name': tag.name,
+            'type': tag.type,
+            'binding_count': 0
+        }
+
+        return response, 200
+
+
+class TagUpdateDeleteApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self, tag_id):
+        tag_id = str(tag_id)
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', nullable=False, required=True,
+                            help='Name must be between 1 to 50 characters.',
+                            type=_validate_name)
+        args = parser.parse_args()
+        tag = TagService.update_tags(args, tag_id)
+
+        binding_count = TagService.get_tag_binding_count(tag_id)
+
+        response = {
+            'id': tag.id,
+            'name': tag.name,
+            'type': tag.type,
+            'binding_count': binding_count
+        }
+
+        return response, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, tag_id):
+        tag_id = str(tag_id)
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        TagService.delete_tag(tag_id)
+
+        return 200
+
+
+class TagBindingCreateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
+                            help='Tag IDs is required.')
+        parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
+                            help='Target ID is required.')
+        parser.add_argument('type', type=str, location='json',
+                            choices=Tag.TAG_TYPE_LIST,
+                            nullable=True,
+                            help='Invalid tag type.')
+        args = parser.parse_args()
+        TagService.save_tag_binding(args)
+
+        return 200
+
+
+class TagBindingDeleteApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_admin_or_owner:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('tag_id', type=str, nullable=False, required=True,
+                            help='Tag ID is required.')
+        parser.add_argument('target_id', type=str, nullable=False, required=True,
+                            help='Target ID is required.')
+        parser.add_argument('type', type=str, location='json',
+                            choices=Tag.TAG_TYPE_LIST,
+                            nullable=True,
+                            help='Invalid tag type.')
+        args = parser.parse_args()
+        TagService.delete_tag_binding(args)
+
+        return 200
+
+
+api.add_resource(TagListApi, '/tags')
+api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
+api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
+api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
```
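A hypothetical sequence against the new tag endpoints as a usage sketch — the base URL, auth handling, the `"app"` tag type, and the IDs are placeholders (the accepted `type` values actually come from `Tag.TAG_TYPE_LIST`):

```python
import requests

BASE = "http://localhost:5001/console/api"  # assumed mount point; auth omitted

# Create a tag (the caller must be a workspace admin or owner).
tag = requests.post(f"{BASE}/tags", json={"name": "internal", "type": "app"}).json()

# Bind the tag to a target object, then remove the binding again.
requests.post(f"{BASE}/tag-bindings/create",
              json={"tag_ids": [tag["id"]], "target_id": "<target_uuid>", "type": "app"})
requests.post(f"{BASE}/tag-bindings/remove",
              json={"tag_id": tag["id"], "target_id": "<target_uuid>", "type": "app"})
```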
```diff
@@ -9,7 +9,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
 from extensions.ext_database import db
 from fields.member_fields import account_with_role_list_fields
 from libs.login import login_required
-from models.account import Account
+from models.account import Account, TenantAccountRole
 from services.account_service import RegisterService, TenantService
 from services.errors.account import AccountAlreadyInTenantError

@@ -43,7 +43,7 @@ class MemberInviteEmailApi(Resource):
         invitee_emails = args['emails']
         invitee_role = args['role']
         interface_language = args['language']
-        if invitee_role not in ['admin', 'normal']:
+        if invitee_role not in [TenantAccountRole.ADMIN, TenantAccountRole.NORMAL]:
             return {'code': 'invalid-role', 'message': 'Invalid role'}, 400

         inviter = current_user
```
```diff
@@ -11,6 +11,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import login_required
+from models.account import TenantAccountRole
 from services.model_provider_service import ModelProviderService


@@ -94,7 +95,7 @@ class ModelProviderModelApi(Resource):
     @login_required
     @account_initialization_required
     def post(self, provider: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+        if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
             raise Forbidden()

         tenant_id = current_user.current_tenant_id
@@ -125,7 +126,7 @@ class ModelProviderModelApi(Resource):
     @login_required
     @account_initialization_required
     def delete(self, provider: str):
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+        if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
             raise Forbidden()

         tenant_id = current_user.current_tenant_id
```
```diff
@@ -26,8 +26,11 @@ class DatasetApi(DatasetApiResource):
         page = request.args.get('page', default=1, type=int)
         limit = request.args.get('limit', default=20, type=int)
         provider = request.args.get('provider', default="vendor")
+        search = request.args.get('keyword', default=None, type=str)
+        tag_ids = request.args.getlist('tag_ids')

         datasets, total = DatasetService.get_datasets(page, limit, provider,
-                                                      tenant_id, current_user)
+                                                      tenant_id, current_user, search, tag_ids)
         # check embedding setting
         provider_manager = ProviderManager()
         configurations = provider_manager.get_configurations(
```
```diff
@@ -163,6 +163,7 @@ class BaseAgentRunner(AppRunner):
         """
         tool_entity = ToolManager.get_agent_tool_runtime(
             tenant_id=self.tenant_id,
+            app_id=self.app_config.app_id,
             agent_tool=tool,
         )
         tool_entity.load_variables(self.variables_pool)
```
```diff
@@ -18,7 +18,7 @@ from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
 from extensions.ext_database import db
-from models.model import App, Conversation, Message
+from models.model import App, Conversation, EndUser, Message
 from models.workflow import Workflow

 logger = logging.getLogger(__name__)
@@ -56,6 +56,14 @@ class AdvancedChatAppRunner(AppRunner):
         query = application_generate_entity.query
         files = application_generate_entity.files

+        user_id = None
+        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
+            if end_user:
+                user_id = end_user.session_id
+        else:
+            user_id = application_generate_entity.user_id
+
         # moderation
         if self.handle_input_moderation(
             queue_manager=queue_manager,
@@ -98,7 +106,8 @@ class AdvancedChatAppRunner(AppRunner):
             system_inputs={
                 SystemVariable.QUERY: query,
                 SystemVariable.FILES: files,
-                SystemVariable.CONVERSATION: conversation.id,
+                SystemVariable.CONVERSATION_ID: conversation.id,
+                SystemVariable.USER_ID: user_id
             },
             callbacks=workflow_callbacks
         )
```
```diff
@@ -84,13 +84,19 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         """
         super().__init__(application_generate_entity, queue_manager, user, stream)

+        if isinstance(self._user, EndUser):
+            user_id = self._user.session_id
+        else:
+            user_id = self._user.id
+
         self._workflow = workflow
         self._conversation = conversation
         self._message = message
         self._workflow_system_variables = {
             SystemVariable.QUERY: message.query,
             SystemVariable.FILES: application_generate_entity.files,
-            SystemVariable.CONVERSATION: conversation.id,
+            SystemVariable.CONVERSATION_ID: conversation.id,
+            SystemVariable.USER_ID: user_id
         }

         self._task_state = AdvancedChatTaskState(
```
```diff
@@ -23,20 +23,28 @@ class BaseAppGenerator:
             value = user_inputs[variable]

             if value:
-                if not isinstance(value, str):
+                if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
                     raise ValueError(f"{variable} in input form must be a string")
+                elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
+                    if '.' in value:
+                        value = float(value)
+                    else:
+                        value = int(value)

             if variable_config.type == VariableEntity.Type.SELECT:
                 options = variable_config.options if variable_config.options is not None else []
                 if value not in options:
                     raise ValueError(f"{variable} in input form must be one of the following: {options}")
-            else:
+            elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
                 if variable_config.max_length is not None:
                     max_length = variable_config.max_length
                     if len(value) > max_length:
                         raise ValueError(f'{variable} in input form must be less than {max_length} characters')

-            filtered_inputs[variable] = value.replace('\x00', '') if value else None
+            if value and isinstance(value, str):
+                filtered_inputs[variable] = value.replace('\x00', '')
+            else:
+                filtered_inputs[variable] = value if value else None

         return filtered_inputs
```
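With the `BaseAppGenerator` change, NUMBER form fields now accept numeric strings and coerce them before validation instead of rejecting every non-string value. The coercion rule in isolation, as a sketch rather than the class itself:

```python
def coerce_form_number(value: str) -> float | int:
    # Mirrors the new NUMBER handling: a '.' selects float, otherwise int;
    # a non-numeric string still raises ValueError, as before.
    return float(value) if '.' in value else int(value)


assert coerce_form_number("3.14") == 3.14
assert coerce_form_number("42") == 42
```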
```diff
@@ -14,7 +14,7 @@ from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.nodes.base_node import UserFrom
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
 from extensions.ext_database import db
-from models.model import App
+from models.model import App, EndUser
 from models.workflow import Workflow

 logger = logging.getLogger(__name__)
@@ -36,6 +36,14 @@ class WorkflowAppRunner:
         app_config = application_generate_entity.app_config
         app_config = cast(WorkflowAppConfig, app_config)

+        user_id = None
+        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
+            if end_user:
+                user_id = end_user.session_id
+        else:
+            user_id = application_generate_entity.user_id
+
         app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
         if not app_record:
             raise ValueError("App not found")
@@ -67,7 +75,8 @@ class WorkflowAppRunner:
                 else UserFrom.END_USER,
             user_inputs=inputs,
             system_inputs={
-                SystemVariable.FILES: files
+                SystemVariable.FILES: files,
+                SystemVariable.USER_ID: user_id
             },
             callbacks=workflow_callbacks
         )
```
```diff
@@ -71,9 +71,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
         """
         super().__init__(application_generate_entity, queue_manager, user, stream)

+        if isinstance(self._user, EndUser):
+            user_id = self._user.session_id
+        else:
+            user_id = self._user.id
+
         self._workflow = workflow
         self._workflow_system_variables = {
             SystemVariable.FILES: application_generate_entity.files,
+            SystemVariable.USER_ID: user_id
         }

         self._task_state = WorkflowTaskState()
```
```diff
@@ -72,7 +72,7 @@ class AppGenerateEntity(BaseModel):
     # app config
     app_config: AppConfig

-    inputs: dict[str, str]
+    inputs: dict[str, Any]
     files: list[FileVar] = []
     user_id: str
```
```diff
@@ -118,7 +118,8 @@ class MessageCycleManage:
         :param event: event
         :return:
         """
-        self._task_state.metadata['retriever_resources'] = event.retriever_resources
+        if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
+            self._task_state.metadata['retriever_resources'] = event.retriever_resources

     def _get_response_metadata(self) -> dict:
         """
```
```diff
@@ -1,10 +1,13 @@
 import json
 import re
+from base64 import b64encode

 from core.helper.code_executor.template_transformer import TemplateTransformer

 PYTHON_RUNNER = """
 import jinja2
+from json import loads
+from base64 import b64decode

 template = jinja2.Template('''{{code}}''')

@@ -12,7 +15,8 @@ def main(**inputs):
     return template.render(**inputs)

 # execute main function, and return the result
-output = main(**{{inputs}})
+inputs = b64decode('{{inputs}}').decode('utf-8')
+output = main(**loads(inputs))

 result = f'''<<RESULT>>{output}<<RESULT>>'''

@@ -39,6 +43,7 @@ JINJA2_PRELOAD_TEMPLATE = """{% set fruits = ['Apple'] %}

 JINJA2_PRELOAD = f"""
 import jinja2
+from base64 import b64decode

 def _jinja2_preload_():
     # prepare jinja2 environment, load template and render before to avoid sandbox issue
@@ -60,9 +65,11 @@ class Jinja2TemplateTransformer(TemplateTransformer):
         :return:
         """
+        inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
+
         # transform jinja2 template to python code
         runner = PYTHON_RUNNER.replace('{{code}}', code)
-        runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4, ensure_ascii=False))
+        runner = runner.replace('{{inputs}}', inputs_str)

         return runner, JINJA2_PRELOAD
```
```diff
@@ -1,17 +1,22 @@
 import json
 import re
+from base64 import b64encode

 from core.helper.code_executor.template_transformer import TemplateTransformer

 PYTHON_RUNNER = """# declare main function here
 {{code}}

+from json import loads, dumps
+from base64 import b64decode
+
 # execute main function, and return the result
-# inputs is a dict, and it
-output = main(**{{inputs}})
+inputs = b64decode('{{inputs}}').decode('utf-8')
+output = main(**json.loads(inputs))

 # convert output to json and print
-output = json.dumps(output, indent=4)
+output = dumps(output, indent=4)

 result = f'''<<RESULT>>
 {output}
@@ -54,7 +59,7 @@ class PythonTemplateTransformer(TemplateTransformer):
         """

         # transform inputs to json string
-        inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)
+        inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')

         # replace code and inputs
         runner = PYTHON_RUNNER.replace('{{code}}', code)
```
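Both template transformers stop splicing a JSON literal directly into the generated runner and instead pass it base64-encoded, which sidesteps quoting and escaping problems when the inputs contain quotes, backslashes, or `{{ }}` sequences. A minimal round-trip sketch of the encoding scheme used above:

```python
import json
from base64 import b64decode, b64encode

inputs = {"text": "tricky '''quotes''', \\backslashes\\ and {{braces}}"}

# Host side: what the transformer embeds into the runner template.
encoded = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')

# Sandbox side: what the runner decodes before calling main(**inputs).
decoded = json.loads(b64decode(encoded).decode('utf-8'))
assert decoded == inputs
```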
```diff
@@ -11,12 +11,13 @@ class ToolParameterCacheType(Enum):

 class ToolParameterCache:
     def __init__(self,
-        tenant_id: str,
-        provider: str,
-        tool_name: str,
-        cache_type: ToolParameterCacheType
+        tenant_id: str,
+        provider: str,
+        tool_name: str,
+        cache_type: ToolParameterCacheType,
+        identity_id: str
     ):
-        self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
+        self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"

     def get(self) -> Optional[dict]:
         """
```
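Adding `identity_id` to the cache key scopes cached tool parameters per identity (for agent tools, `AGENT.{app_id}` per the hunks earlier) instead of sharing one entry per tenant/provider/tool. A sketch of the resulting key shape — the concrete values are illustrative, and the real prefix comes from `cache_type.value`:

```python
def tool_parameter_cache_key(cache_type_value: str, tenant_id: str, provider: str,
                             tool_name: str, identity_id: str) -> str:
    # Same shape as the new self.cache_key above.
    return (f"{cache_type_value}_secret:tenant_id:{tenant_id}:provider:{provider}"
            f":tool_name:{tool_name}:identity_id:{identity_id}")


# Two apps using the same tool now get distinct cache entries.
k1 = tool_parameter_cache_key("tool_parameter", "t-1", "google", "search", "AGENT.app-1")
k2 = tool_parameter_cache_key("tool_parameter", "t-1", "google", "search", "AGENT.app-2")
assert k1 != k2
```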
```diff
@@ -10,3 +10,6 @@
 - cohere.command-text-v14
 - meta.llama2-13b-chat-v1
 - meta.llama2-70b-chat-v1
+- mistral.mistral-large-2402-v1:0
+- mistral.mixtral-8x7b-instruct-v0:1
+- mistral.mistral-7b-instruct-v0:2
```
```diff
@@ -449,6 +449,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             human_prompt_prefix = "\n[INST]"
             human_prompt_postfix = "[\\INST]\n"
             ai_prompt = ""

+        elif model_prefix == "mistral":
+            human_prompt_prefix = "<s>[INST]"
+            human_prompt_postfix = "[\\INST]\n"
+            ai_prompt = "\n\nAssistant:"
+
         elif model_prefix == "amazon":
             human_prompt_prefix = "\n\nUser:"
@@ -519,6 +524,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
             if model_parameters.get("countPenalty"):
                 payload["countPenalty"] = {model_parameters.get("countPenalty")}

+        elif model_prefix == "mistral":
+            payload["temperature"] = model_parameters.get("temperature")
+            payload["top_p"] = model_parameters.get("top_p")
+            payload["max_tokens"] = model_parameters.get("max_tokens")
+            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+            payload["stop"] = stop[:10] if stop else []
+
         elif model_prefix == "anthropic":
             payload = { **model_parameters }
@@ -648,6 +660,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             output = response_body.get("generation").strip('\n')
             prompt_tokens = response_body.get("prompt_token_count")
             completion_tokens = response_body.get("generation_token_count")

+        elif model_prefix == "mistral":
+            output = response_body.get("outputs")[0].get("text")
+            prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count')
+            completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count')
+
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@@ -731,6 +748,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             content_delta = payload.get("text")
             finish_reason = payload.get("finish_reason")

+        elif model_prefix == "mistral":
+            content_delta = payload.get('outputs')[0].get("text")
+            finish_reason = payload.get('outputs')[0].get("stop_reason")
+
         elif model_prefix == "meta":
             content_delta = payload.get("generation").strip('\n')
             finish_reason = payload.get("stop_reason")
```
New file: model definition for mistral.mistral-7b-instruct-v0:2 (39 lines)

```diff
@@ -0,0 +1,39 @@
+model: mistral.mistral-7b-instruct-v0:2
+label:
+  en_US: Mistral 7B Instruct
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    required: false
+    default: 0.5
+  - name: top_p
+    use_template: top_p
+    required: false
+    default: 0.9
+  - name: top_k
+    use_template: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 50
+    max: 200
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00015'
+  output: '0.0002'
+  unit: '0.00001'
+  currency: USD
```

New file: model definition for mistral.mistral-large-2402-v1:0 (27 lines)

```diff
@@ -0,0 +1,27 @@
+model: mistral.mistral-large-2402-v1:0
+label:
+  en_US: Mistral Large
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    required: false
+    default: 0.7
+  - name: top_p
+    use_template: top_p
+    required: false
+    default: 1
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 4096
+pricing:
+  input: '0.008'
+  output: '0.024'
+  unit: '0.001'
+  currency: USD
```

New file: model definition for mistral.mixtral-8x7b-instruct-v0:1 (39 lines)

```diff
@@ -0,0 +1,39 @@
+model: mistral.mixtral-8x7b-instruct-v0:1
+label:
+  en_US: Mixtral 8X7B Instruct
+model_type: llm
+model_properties:
+  mode: completion
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    required: false
+    default: 0.5
+  - name: top_p
+    use_template: top_p
+    required: false
+    default: 0.9
+  - name: top_k
+    use_template: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 50
+    max: 200
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 512
+    min: 1
+    max: 8192
+pricing:
+  input: '0.00045'
+  output: '0.0007'
+  unit: '0.00001'
+  currency: USD
```
```diff
@@ -19,7 +19,7 @@ class GroqProvider(ModelProvider):
             model_instance = self.get_model_instance(ModelType.LLM)

             model_instance.validate_credentials(
-                model='llama2-70b-4096',
+                model='llama3-8b-8192',
                 credentials=credentials
             )
         except CredentialsValidateFailedError as ex:
```
New file: model definition for llama3-70b-8192 (25 lines)

```diff
@@ -0,0 +1,25 @@
+model: llama3-70b-8192
+label:
+  zh_Hans: Llama-3-70B-8192
+  en_US: Llama-3-70B-8192
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 8192
+pricing:
+  input: '0.05'
+  output: '0.1'
+  unit: '0.000001'
+  currency: USD
```

New file: model definition for llama3-8b-8192 (25 lines)

```diff
@@ -0,0 +1,25 @@
+model: llama3-8b-8192
+label:
+  zh_Hans: Llama-3-8B-8192
+  en_US: Llama-3-8B-8192
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 8192
+pricing:
+  input: '0.59'
+  output: '0.79'
+  unit: '0.000001'
+  currency: USD
```
```diff
@@ -47,6 +47,20 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
             if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                 raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

+            if credentials['server_url'].endswith('/'):
+                credentials['server_url'] = credentials['server_url'][:-1]
+
+            # initialize client
+            client = Client(
+                base_url=credentials['server_url']
+            )
+
+            xinference_client = client.get_model(model_uid=credentials['model_uid'])
+
+            if not isinstance(xinference_client, RESTfulAudioModelHandle):
+                raise InvokeBadRequestError(
+                    'please check model type, the model you want to invoke is not a audio model')
+
             audio_file_path = self._get_demo_file_path()

             with open(audio_file_path, 'rb') as audio_file:
@@ -110,17 +124,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
         if credentials['server_url'].endswith('/'):
             credentials['server_url'] = credentials['server_url'][:-1]

-        # initialize client
-        client = Client(
-            base_url=credentials['server_url']
-        )
-
-        xinference_client = client.get_model(model_uid=credentials['model_uid'])
-
-        if not isinstance(xinference_client, RESTfulAudioModelHandle):
-            raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model')
-
-        response = xinference_client.transcriptions(
+        handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
+        response = handle.transcriptions(
             audio=file,
             language = language,
             prompt = prompt,
```
```diff
@@ -31,7 +31,10 @@ class AdvancedPromptTransform(PromptTransform):
                      context: Optional[str],
                      memory_config: Optional[MemoryConfig],
                      memory: Optional[TokenBufferMemory],
-                     model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+                     model_config: ModelConfigWithCredentialsEntity,
+                     query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
+        inputs = {key: str(value) for key, value in inputs.items()}
+
         prompt_messages = []

         model_mode = ModelMode.value_of(model_config.mode)
@@ -51,6 +54,7 @@ class AdvancedPromptTransform(PromptTransform):
                 prompt_template=prompt_template,
                 inputs=inputs,
                 query=query,
+                query_prompt_template=query_prompt_template,
                 files=files,
                 context=context,
                 memory_config=memory_config,
@@ -119,7 +123,8 @@ class AdvancedPromptTransform(PromptTransform):
                                        context: Optional[str],
                                        memory_config: Optional[MemoryConfig],
                                        memory: Optional[TokenBufferMemory],
-                                       model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
+                                       model_config: ModelConfigWithCredentialsEntity,
+                                       query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
         """
@@ -146,6 +151,20 @@ class AdvancedPromptTransform(PromptTransform):
             elif prompt_item.role == PromptMessageRole.ASSISTANT:
                 prompt_messages.append(AssistantPromptMessage(content=prompt))

+        if query and query_prompt_template:
+            prompt_template = PromptTemplateParser(
+                template=query_prompt_template,
+                with_variable_tmpl=self.with_variable_tmpl
+            )
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            prompt_inputs['#sys.query#'] = query
+
+            prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
+
+            query = prompt_template.format(
+                prompt_inputs
+            )
+
         if memory and memory_config:
             prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
```

```diff
@@ -40,3 +40,4 @@ class MemoryConfig(BaseModel):

     role_prefix: Optional[RolePrefix] = None
     window: WindowConfig
+    query_prompt_template: Optional[str] = None
```
```diff
@@ -55,6 +55,8 @@ class SimplePromptTransform(PromptTransform):
                  memory: Optional[TokenBufferMemory],
                  model_config: ModelConfigWithCredentialsEntity) -> \
             tuple[list[PromptMessage], Optional[list[str]]]:
+        inputs = {key: str(value) for key, value in inputs.items()}
+
         model_mode = ModelMode.value_of(model_config.mode)
         if model_mode == ModelMode.CHAT:
             prompt_messages, stops = self._get_chat_model_prompt_messages(
```
```diff
@@ -110,19 +110,37 @@ class MilvusVector(BaseVector):
         return None

     def delete_by_metadata_field(self, key: str, value: str):
+        alias = uuid4().hex
+        if self._client_config.secure:
+            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        else:
+            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)

-        ids = self.get_ids_by_metadata_field(key, value)
-        if ids:
-            self._client.delete(collection_name=self._collection_name, pks=ids)
+        from pymilvus import utility
+        if utility.has_collection(self._collection_name, using=alias):
+
+            ids = self.get_ids_by_metadata_field(key, value)
+            if ids:
+                self._client.delete(collection_name=self._collection_name, pks=ids)

     def delete_by_ids(self, doc_ids: list[str]) -> None:
+        alias = uuid4().hex
+        if self._client_config.secure:
+            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        else:
+            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)

-        result = self._client.query(collection_name=self._collection_name,
-                                    filter=f'metadata["doc_id"] in {doc_ids}',
-                                    output_fields=["id"])
-        if result:
-            ids = [item["id"] for item in result]
-            self._client.delete(collection_name=self._collection_name, pks=ids)
+        from pymilvus import utility
+        if utility.has_collection(self._collection_name, using=alias):
+
+            result = self._client.query(collection_name=self._collection_name,
+                                        filter=f'metadata["doc_id"] in {doc_ids}',
+                                        output_fields=["id"])
+            if result:
+                ids = [item["id"] for item in result]
+                self._client.delete(collection_name=self._collection_name, pks=ids)

     def delete(self) -> None:
         alias = uuid4().hex
```
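The Milvus delete paths now open a connection under a fresh alias and probe `utility.has_collection` first, so deleting from a collection that was never created becomes a no-op instead of an error. The alias pattern in isolation — the URI and credentials below are placeholders:

```python
from uuid import uuid4

from pymilvus import connections, utility

alias = uuid4().hex  # a unique alias avoids clobbering the default connection
connections.connect(alias=alias, uri="http://localhost:19530", user="", password="")

if utility.has_collection("my_collection", using=alias):
    pass  # safe to query and delete; otherwise silently skip
```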
```diff
@@ -50,7 +50,8 @@ class QdrantConfig(BaseModel):
         return {
             'url': self.endpoint,
             'api_key': self.api_key,
-            'timeout': self.timeout
+            'timeout': self.timeout,
+            'verify': self.endpoint.startswith('https')
         }
```
@@ -217,29 +218,38 @@ class QdrantVector(BaseVector):
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
self._reload_if_needed()
|
||||
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
self._reload_if_needed()
|
||||
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def delete(self):
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
@@ -257,29 +267,40 @@ class QdrantVector(BaseVector):
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
for node_id in ids:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
|
||||
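The try/except added to both delete paths treats a 404 from Qdrant as "collection already gone" and swallows it, re-raising anything else. The same idiom in isolation, assuming a local Qdrant (endpoint and names are placeholders):

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse

client = QdrantClient(url="http://localhost:6333")  # placeholder endpoint

try:
    client.delete(
        collection_name="maybe_missing",
        points_selector=models.FilterSelector(
            filter=models.Filter(must=[
                models.FieldCondition(key="metadata.doc_id",
                                      match=models.MatchValue(value="doc-1")),
            ])
        ),
    )
except UnexpectedResponse as e:
    if e.status_code == 404:
        pass  # collection does not exist, nothing to delete
    else:
        raise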
@@ -121,18 +121,20 @@ class WeaviateVector(BaseVector):
         return ids

     def delete_by_metadata_field(self, key: str, value: str):
-        where_filter = {
-            "operator": "Equal",
-            "path": [key],
-            "valueText": value
-        }
-
-        self._client.batch.delete_objects(
-            class_name=self._collection_name,
-            where=where_filter,
-            output='minimal'
-        )
+        # check whether the index already exists
+        schema = self._default_schema(self._collection_name)
+        if self._client.schema.contains(schema):
+            where_filter = {
+                "operator": "Equal",
+                "path": [key],
+                "valueText": value
+            }
+
+            self._client.batch.delete_objects(
+                class_name=self._collection_name,
+                where=where_filter,
+                output='minimal'
+            )

     def delete(self):
         # check whether the index already exists
@@ -163,11 +165,14 @@ class WeaviateVector(BaseVector):
         return True

     def delete_by_ids(self, ids: list[str]) -> None:
-        for uuid in ids:
-            self._client.data_object.delete(
-                class_name=self._collection_name,
-                uuid=uuid,
-            )
+        # check whether the index already exists
+        schema = self._default_schema(self._collection_name)
+        if self._client.schema.contains(schema):
+            for uuid in ids:
+                self._client.data_object.delete(
+                    class_name=self._collection_name,
+                    uuid=uuid,
+                )

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         """Look up similar documents by embedding vector in Weaviate."""
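Same idea for Weaviate, but probed through the schema API rather than an exception: schema.contains answers whether the class exists before any delete is attempted. A hedged sketch against the v3 weaviate-client the code above targets (endpoint, class and field names are placeholders):

import weaviate

client = weaviate.Client("http://localhost:8080")  # placeholder endpoint

schema = {"classes": [{"class": "MyCollection"}]}  # simplified shape of _default_schema's output
if client.schema.contains(schema):
    client.batch.delete_objects(
        class_name="MyCollection",
        where={"operator": "Equal", "path": ["doc_id"], "valueText": "doc-1"},
        output="minimal",
    )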
@@ -44,27 +44,31 @@ class BingSearchTool(BuiltinTool):
         results = []
         if search_results:
             for result in search_results:
+                url = f': {result["url"]}' if "url" in result else ""
                 results.append(self.create_text_message(
-                    text=f'{result["name"]}: {result["url"]}'
+                    text=f'{result["name"]}{url}'
                 ))

         if entities:
             for entity in entities:
+                url = f': {entity["url"]}' if "url" in entity else ""
                 results.append(self.create_text_message(
-                    text=f'{entity["name"]}: {entity["url"]}'
+                    text=f'{entity.get("name", "")}{url}'
                 ))

         if news:
             for news_item in news:
+                url = f': {news_item["url"]}' if "url" in news_item else ""
                 results.append(self.create_text_message(
-                    text=f'{news_item["name"]}: {news_item["url"]}'
+                    text=f'{news_item.get("name", "")}{url}'
                 ))

         if related_searches:
             for related in related_searches:
+                url = f': {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
                 results.append(self.create_text_message(
-                    text=f'{related["displayText"]}: {related["webSearchUrl"]}'
+                    text=f'{related.get("displayText", "")}{url}'
                 ))

         return results
@@ -73,7 +77,7 @@ class BingSearchTool(BuiltinTool):
         text = ''
         if search_results:
             for i, result in enumerate(search_results):
-                text += f'{i+1}: {result["name"]} - {result["snippet"]}\n'
+                text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n'

         if computation and 'expression' in computation and 'value' in computation:
             text += '\nComputation:\n'
@@ -82,17 +86,20 @@ class BingSearchTool(BuiltinTool):
         if entities:
             text += '\nEntities:\n'
             for entity in entities:
-                text += f'{entity["name"]} - {entity["url"]}\n'
+                url = f'- {entity["url"]}' if "url" in entity else ""
+                text += f'{entity.get("name", "")}{url}\n'

         if news:
             text += '\nNews:\n'
             for news_item in news:
-                text += f'{news_item["name"]} - {news_item["url"]}\n'
+                url = f'- {news_item["url"]}' if "url" in news_item else ""
+                text += f'{news_item.get("name", "")}{url}\n'

         if related_searches:
             text += '\n\nRelated Searches:\n'
             for related in related_searches:
-                text += f'{related["displayText"]} - {related["webSearchUrl"]}\n'
+                url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
+                text += f'{related.get("displayText", "")}{url}\n'

         return self.create_text_message(text=self.summary(user_id=user_id, content=text))
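The shape of the fix above: precompute an optional suffix with a membership test, then format with dict.get, so a Bing result missing "url" or "name" renders as partial text instead of raising KeyError. In isolation:

result = {"name": "Dify"}  # hypothetical Bing hit with no "url" field

# Before: f'{result["name"]}: {result["url"]}' -> KeyError: 'url'
url = f': {result["url"]}' if "url" in result else ""
print(f'{result.get("name", "")}{url}')  # prints just "Dify"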
@@ -7,10 +7,10 @@ identity:
     pt_BR: Interpretador de Código
   description:
     human:
-      en_US: Run code and get the result back, when you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code.
-      zh_Hans: 运行一段代码并返回结果,当您使用较低质量的模型时,请确保有一些提示帮助LLM理解如何编写代码。
-      pt_BR: Execute um trecho de código e obtenha o resultado de volta, quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
-    llm: A tool for running code and getting the result back, but only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty.
+      en_US: Run code and get the result back. When you're using a lower quality model, please make sure there are some tips help LLM to understand how to write the code.
+      zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时,请确保有一些提示帮助LLM理解如何编写代码。
+      pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
+    llm: A tool for running code and getting the result back. Only native packages are allowed, network/IO operations are disabled. and you must use print() or console.log() to output the result or result will be empty.
 parameters:
   - name: language
     type: string
21  api/core/tools/provider/builtin/judge0ce/_assets/icon.svg  Normal file
@@ -0,0 +1,21 @@
+<?xml version="1.0" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 20010904//EN"
+ "http://www.w3.org/TR/2001/REC-SVG-20010904/DTD/svg10.dtd">
+<svg version="1.0" xmlns="http://www.w3.org/2000/svg"
+ width="128.000000pt" height="128.000000pt" viewBox="0 0 128.000000 128.000000"
+ preserveAspectRatio="xMidYMid meet">
+
+<g transform="translate(0.000000,128.000000) scale(0.100000,-0.100000)"
+fill="#000000" stroke="none">
+<path d="M0 975 l0 -305 33 1 c54 0 336 35 343 41 3 4 0 57 -7 118 -10 85 -17
+113 -29 120 -47 25 -45 104 2 133 13 8 118 26 246 41 208 26 225 26 248 11 14
+-9 30 -27 36 -41 10 -22 8 -33 -10 -68 l-23 -42 40 -316 40 -315 30 -31 c17
+-17 31 -38 31 -47 0 -25 -27 -72 -46 -79 -35 -13 -450 -59 -476 -53 -52 13
+-70 85 -32 127 10 13 10 33 -1 120 -8 58 -15 111 -15 118 0 16 -31 16 -237 -5
+l-173 -17 0 -243 0 -243 640 0 640 0 0 640 0 640 -640 0 -640 0 0 -305z"/>
+<path d="M578 977 c-128 -16 -168 -24 -168 -35 0 -10 8 -12 28 -8 15 3 90 12
+167 21 167 18 188 23 180 35 -7 12 -1 12 -207 -13z"/>
+<path d="M660 326 c-100 -13 -163 -25 -160 -31 3 -5 14 -9 25 -8 104 11 305
+35 323 39 12 2 22 9 22 14 0 13 -14 12 -210 -14z"/>
+</g>
+</svg>
23  api/core/tools/provider/builtin/judge0ce/judge0ce.py  Normal file
@@ -0,0 +1,23 @@
+from typing import Any
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.judge0ce.tools.submitCodeExecutionTask import SubmitCodeExecutionTaskTool
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class Judge0CEProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        try:
+            SubmitCodeExecutionTaskTool().fork_tool_runtime(
+                meta={
+                    "credentials": credentials,
+                }
+            ).invoke(
+                user_id='',
+                tool_parameters={
+                    "source_code": "print('hello world')",
+                    "language_id": 71,
+                },
+            )
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))
29  api/core/tools/provider/builtin/judge0ce/judge0ce.yaml  Normal file
@@ -0,0 +1,29 @@
+identity:
+  author: Richards Tu
+  name: judge0ce
+  label:
+    en_US: Judge0 CE
+    zh_Hans: Judge0 CE
+    pt_BR: Judge0 CE
+  description:
+    en_US: Judge0 CE is an open-source code execution system. Support various languages, including C, C++, Java, Python, Ruby, etc.
+    zh_Hans: Judge0 CE 是一个开源的代码执行系统。支持多种语言,包括 C、C++、Java、Python、Ruby 等。
+    pt_BR: Judge0 CE é um sistema de execução de código de código aberto. Suporta várias linguagens, incluindo C, C++, Java, Python, Ruby, etc.
+  icon: icon.svg
+credentials_for_provider:
+  X-RapidAPI-Key:
+    type: secret-input
+    required: true
+    label:
+      en_US: RapidAPI Key
+      zh_Hans: RapidAPI Key
+      pt_BR: RapidAPI Key
+    help:
+      en_US: RapidAPI Key is required to access the Judge0 CE API.
+      zh_Hans: RapidAPI Key 是访问 Judge0 CE API 所必需的。
+      pt_BR: RapidAPI Key é necessário para acessar a API do Judge0 CE.
+    placeholder:
+      en_US: Enter your RapidAPI Key
+      zh_Hans: 输入你的 RapidAPI Key
+      pt_BR: Insira sua RapidAPI Key
+    url: https://rapidapi.com/judge0-official/api/judge0-ce
@@ -0,0 +1,37 @@
+from typing import Any, Union
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class GetExecutionResultTool(BuiltinTool):
+    def _invoke(self,
+                user_id: str,
+                tool_parameters: dict[str, Any],
+                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        invoke tools
+        """
+        api_key = self.runtime.credentials['X-RapidAPI-Key']
+
+        url = f"https://judge0-ce.p.rapidapi.com/submissions/{tool_parameters['token']}"
+        headers = {
+            "X-RapidAPI-Key": api_key
+        }
+
+        response = requests.get(url, headers=headers)
+
+        if response.status_code == 200:
+            result = response.json()
+            return self.create_text_message(text=f"Submission details:\n"
+                                                 f"stdout: {result.get('stdout', '')}\n"
+                                                 f"stderr: {result.get('stderr', '')}\n"
+                                                 f"compile_output: {result.get('compile_output', '')}\n"
+                                                 f"message: {result.get('message', '')}\n"
+                                                 f"status: {result['status']['description']}\n"
+                                                 f"time: {result.get('time', '')} seconds\n"
+                                                 f"memory: {result.get('memory', '')} bytes")
+        else:
+            return self.create_text_message(text=f"Error retrieving submission details: {response.text}")
@@ -0,0 +1,23 @@
+identity:
+  name: getExecutionResult
+  author: Richards Tu
+  label:
+    en_US: Get Execution Result
+    zh_Hans: 获取执行结果
+description:
+  human:
+    en_US: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask.
+    zh_Hans: 一个用于通过 submitCodeExecutionTask 工具提供的特定令牌来检索代码提交详细信息的工具。
+  llm: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask.
+parameters:
+  - name: token
+    type: string
+    required: true
+    label:
+      en_US: Token
+      zh_Hans: 令牌
+    human_description:
+      en_US: The submission's unique token.
+      zh_Hans: 提交的唯一令牌。
+    llm_description: The submission's unique token. MUST get from submitCodeExecution.
+    form: llm
@@ -0,0 +1,49 @@
+import json
+from typing import Any, Union
+
+from httpx import post
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class SubmitCodeExecutionTaskTool(BuiltinTool):
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        invoke tools
+        """
+        api_key = self.runtime.credentials['X-RapidAPI-Key']
+
+        source_code = tool_parameters['source_code']
+        language_id = tool_parameters['language_id']
+        stdin = tool_parameters.get('stdin', '')
+        expected_output = tool_parameters.get('expected_output', '')
+        additional_files = tool_parameters.get('additional_files', '')
+
+        url = "https://judge0-ce.p.rapidapi.com/submissions"
+
+        querystring = {"base64_encoded": "false", "fields": "*"}
+
+        payload = {
+            "language_id": language_id,
+            "source_code": source_code,
+            "stdin": stdin,
+            "expected_output": expected_output,
+            "additional_files": additional_files,
+        }
+
+        headers = {
+            "content-type": "application/json",
+            "Content-Type": "application/json",
+            "X-RapidAPI-Key": api_key,
+            "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com"
+        }
+
+        response = post(url, data=json.dumps(payload), headers=headers, params=querystring)
+
+        if response.status_code != 201:
+            raise Exception(response.text)
+
+        token = response.json()['token']
+
+        return self.create_text_message(text=token)
@@ -0,0 +1,67 @@
+identity:
+  name: submitCodeExecutionTask
+  author: Richards Tu
+  label:
+    en_US: Submit Code Execution Task
+    zh_Hans: 提交代码执行任务
+description:
+  human:
+    en_US: A tool for submitting code execution task to Judge0 CE.
+    zh_Hans: 一个用于向 Judge0 CE 提交代码执行任务的工具。
+  llm: A tool for submitting a new code execution task to Judge0 CE. It takes in the source code, language ID, standard input (optional), expected output (optional), and additional files (optional) as parameters; and returns a unique token representing the submission.
+parameters:
+  - name: source_code
+    type: string
+    required: true
+    label:
+      en_US: Source Code
+      zh_Hans: 源代码
+    human_description:
+      en_US: The source code to be executed.
+      zh_Hans: 要执行的源代码。
+    llm_description: The source code to be executed.
+    form: llm
+  - name: language_id
+    type: number
+    required: true
+    label:
+      en_US: Language ID
+      zh_Hans: 语言 ID
+    human_description:
+      en_US: The ID of the language in which the source code is written.
+      zh_Hans: 源代码所使用的语言的 ID。
+    llm_description: The ID of the language in which the source code is written. For example, 50 for C++, 71 for Python, etc.
+    form: llm
+  - name: stdin
+    type: string
+    required: false
+    label:
+      en_US: Standard Input
+      zh_Hans: 标准输入
+    human_description:
+      en_US: The standard input to be provided to the program.
+      zh_Hans: 提供给程序的标准输入。
+    llm_description: The standard input to be provided to the program. Optional.
+    form: llm
+  - name: expected_output
+    type: string
+    required: false
+    label:
+      en_US: Expected Output
+      zh_Hans: 期望输出
+    human_description:
+      en_US: The expected output of the program. Used for comparison in some scenarios.
+      zh_Hans: 程序的期望输出。在某些场景下用于比较。
+    llm_description: The expected output of the program. Used for comparison in some scenarios. Optional.
+    form: llm
+  - name: additional_files
+    type: string
+    required: false
+    label:
+      en_US: Additional Files
+      zh_Hans: 附加文件
+    human_description:
+      en_US: Base64 encoded additional files for the submission.
+      zh_Hans: 提交的 Base64 编码的附加文件。
+    llm_description: Base64 encoded additional files for the submission. Optional.
+    form: llm
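Taken together, the two judge0ce tools implement a submit-then-poll protocol: submitCodeExecutionTask POSTs the source and returns Judge0's token, and getExecutionResult GETs the submission by that token. Outside Dify the same round trip looks roughly like the sketch below (the RapidAPI key is a placeholder; 71 is Judge0 CE's id for Python 3):

import time

import requests

HEADERS = {
    "Content-Type": "application/json",
    "X-RapidAPI-Key": "your-rapidapi-key",  # placeholder credential
    "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com",
}

# Submit: mirrors submitCodeExecutionTask's payload.
resp = requests.post(
    "https://judge0-ce.p.rapidapi.com/submissions",
    params={"base64_encoded": "false", "fields": "*"},
    json={"language_id": 71, "source_code": "print('hello world')"},
    headers=HEADERS,
)
token = resp.json()["token"]

# Poll: mirrors getExecutionResult. A real client would loop until the
# status leaves "In Queue"/"Processing" instead of sleeping once.
time.sleep(2)
result = requests.get(f"https://judge0-ce.p.rapidapi.com/submissions/{token}",
                      headers=HEADERS).json()
print(result.get("stdout"), result["status"]["description"])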
@@ -222,7 +222,7 @@ class ToolManager:
         return parameter_value

     @classmethod
-    def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
+    def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
         """
         get the agent tool runtime
         """
@@ -245,6 +245,7 @@ class ToolManager:
             tool_runtime=tool_entity,
             provider_name=agent_tool.provider_id,
             provider_type=agent_tool.provider_type,
+            identity_id=f'AGENT.{app_id}'
         )
         runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)

@@ -252,7 +253,7 @@ class ToolManager:
         return tool_entity

     @classmethod
-    def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
+    def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
         """
         get the workflow tool runtime
         """
@@ -277,6 +278,7 @@ class ToolManager:
             tool_runtime=tool_entity,
             provider_name=workflow_tool.provider_id,
             provider_type=workflow_tool.provider_type,
+            identity_id=f'WORKFLOW.{app_id}.{node_id}'
         )

         if runtime_parameters:
@@ -113,12 +113,13 @@ class ToolParameterConfigurationManager(BaseModel):
     tool_runtime: Tool
     provider_name: str
     provider_type: str
+    identity_id: str

     def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
         """
         deep copy parameters
         """
-        return {key: value for key, value in parameters.items()}
+        return deepcopy(parameters)

     def _merge_parameters(self) -> list[ToolParameter]:
         """
@@ -176,6 +177,8 @@ class ToolParameterConfigurationManager(BaseModel):
         # override parameters
         current_parameters = self._merge_parameters()

+        parameters = self._deep_copy(parameters)
+
         for parameter in current_parameters:
             if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
                 if parameter.name in parameters:
@@ -194,7 +197,8 @@ class ToolParameterConfigurationManager(BaseModel):
             tenant_id=self.tenant_id,
             provider=f'{self.provider_type}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
-            cache_type=ToolParameterCacheType.PARAMETER
+            cache_type=ToolParameterCacheType.PARAMETER,
+            identity_id=self.identity_id
         )
         cached_parameters = cache.get()
         if cached_parameters:
@@ -223,7 +227,8 @@ class ToolParameterConfigurationManager(BaseModel):
             tenant_id=self.tenant_id,
             provider=f'{self.provider_type}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
-            cache_type=ToolParameterCacheType.PARAMETER
+            cache_type=ToolParameterCacheType.PARAMETER,
+            identity_id=self.identity_id
         )
         cache.delete()
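The _deep_copy change is subtle but load-bearing: the old dict comprehension copied only the top level, so nested tool parameters stayed shared with the caller and masking a secret mutated the original. A quick demonstration of the difference:

from copy import deepcopy

params = {"auth": {"api_key": "secret"}}

shallow = {key: value for key, value in params.items()}  # the old _deep_copy
shallow["auth"]["api_key"] = "***"
print(params["auth"]["api_key"])  # '***' -- the caller's dict was mutated

params = {"auth": {"api_key": "secret"}}
deep = deepcopy(params)  # the new _deep_copy
deep["auth"]["api_key"] = "***"
print(params["auth"]["api_key"])  # 'secret' -- original left intact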
@@ -43,7 +43,8 @@ class SystemVariable(Enum):
     """
     QUERY = 'query'
     FILES = 'files'
-    CONVERSATION = 'conversation'
+    CONVERSATION_ID = 'conversation_id'
+    USER_ID = 'user_id'

    @classmethod
    def value_of(cls, value: str) -> 'SystemVariable':
@@ -74,6 +74,7 @@ class LLMNode(BaseNode):
             node_data=node_data,
             query=variable_pool.get_variable_value(['sys', SystemVariable.QUERY.value])
             if node_data.memory else None,
+            query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
             inputs=inputs,
             files=files,
             context=context,
@@ -209,6 +210,17 @@ class LLMNode(BaseNode):

             inputs[variable_selector.variable] = variable_value

+        memory = node_data.memory
+        if memory and memory.query_prompt_template:
+            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                        .extract_variable_selectors())
+            for variable_selector in query_variable_selectors:
+                variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
+                if variable_value is None:
+                    raise ValueError(f'Variable {variable_selector.variable} not found')
+
+                inputs[variable_selector.variable] = variable_value
+
         return inputs

     def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
@@ -302,7 +314,8 @@ class LLMNode(BaseNode):

         return None

-    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[
+        ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
         :param node_data_model: node data model
@@ -385,7 +398,7 @@ class LLMNode(BaseNode):
             return None

         # get conversation id
-        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
+        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION_ID.value])
         if conversation_id is None:
             return None

@@ -407,6 +420,7 @@ class LLMNode(BaseNode):

     def _fetch_prompt_messages(self, node_data: LLMNodeData,
                                query: Optional[str],
+                               query_prompt_template: Optional[str],
                                inputs: dict[str, str],
                                files: list[FileVar],
                                context: Optional[str],
@@ -417,6 +431,7 @@ class LLMNode(BaseNode):
         Fetch prompt messages
         :param node_data: node data
         :param query: query
+        :param query_prompt_template: query prompt template
         :param inputs: inputs
         :param files: files
         :param context: context
@@ -433,7 +448,8 @@ class LLMNode(BaseNode):
             context=context,
             memory_config=node_data.memory,
             memory=memory,
-            model_config=model_config
+            model_config=model_config,
+            query_prompt_template=query_prompt_template,
         )
         stop = model_config.stop

@@ -539,12 +555,22 @@ class LLMNode(BaseNode):
         for variable_selector in variable_selectors:
             variable_mapping[variable_selector.variable] = variable_selector.value_selector

+        memory = node_data.memory
+        if memory and memory.query_prompt_template:
+            query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template)
+                                        .extract_variable_selectors())
+            for variable_selector in query_variable_selectors:
+                variable_mapping[variable_selector.variable] = variable_selector.value_selector
+
         if node_data.context.enabled:
             variable_mapping['#context#'] = node_data.context.variable_selector

         if node_data.vision.enabled:
             variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]

+        if node_data.memory:
+            variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
+
         return variable_mapping

     @classmethod
@@ -1,8 +1,6 @@
 from typing import cast

-from core.app.app_config.entities import VariableEntity
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.start.entities import StartNodeData
@@ -19,17 +17,10 @@ class StartNode(BaseNode):
         :param variable_pool: variable pool
         :return:
         """
-        node_data = self.node_data
-        node_data = cast(self._node_data_cls, node_data)
-        variables = node_data.variables
-
-        # Get cleaned inputs
-        cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs)
+        cleaned_inputs = variable_pool.user_inputs

         for var in variable_pool.system_variables:
-            if var == SystemVariable.CONVERSATION:
-                continue
-
             cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]

         return NodeRunResult(
@@ -38,42 +29,6 @@ class StartNode(BaseNode):
             outputs=cleaned_inputs
         )

-    def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
-        if user_inputs is None:
-            user_inputs = {}
-
-        filtered_inputs = {}
-
-        for variable_config in variables:
-            variable = variable_config.variable
-
-            if variable not in user_inputs or not user_inputs[variable]:
-                if variable_config.required:
-                    raise ValueError(f"Input form variable {variable} is required")
-                else:
-                    filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
-                continue
-
-            value = user_inputs[variable]
-
-            if value:
-                if not isinstance(value, str):
-                    raise ValueError(f"{variable} in input form must be a string")
-
-            if variable_config.type == VariableEntity.Type.SELECT:
-                options = variable_config.options if variable_config.options is not None else []
-                if value not in options:
-                    raise ValueError(f"{variable} in input form must be one of the following: {options}")
-            else:
-                if variable_config.max_length is not None:
-                    max_length = variable_config.max_length
-                    if len(value) > max_length:
-                        raise ValueError(f'{variable} in input form must be less than {max_length} characters')
-
-            filtered_inputs[variable] = value.replace('\x00', '') if value else None
-
-        return filtered_inputs
-
     @classmethod
     def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
         """
@@ -39,7 +39,8 @@ class ToolNode(BaseNode):
         parameters = self._generate_parameters(variable_pool, node_data)
         # get tool runtime
         try:
-            tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
+            tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id,
+                                                                 self.node_id, node_data)
         except Exception as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
@@ -22,5 +22,6 @@ def handle(sender, **kwargs):
         tool_runtime=tool_runtime,
         provider_name=tool_entity.provider_name,
         provider_type=tool_entity.provider_type,
+        identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
     )
     manager.delete_tool_parameters_cache()
@@ -6,6 +6,7 @@ from datetime import datetime, timedelta, timezone
 from typing import Union

 import boto3
+import oss2 as aliyun_s3
 from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
 from botocore.client import Config
 from botocore.exceptions import ClientError
@@ -42,7 +43,14 @@ class Storage:
             )
             self.client = BlobServiceClient(account_url=app.config.get('AZURE_BLOB_ACCOUNT_URL'),
                                             credential=sas_token)

+        elif self.storage_type == 'aliyun-oss':
+            self.bucket_name = app.config.get('ALIYUN_OSS_BUCKET_NAME')
+            self.client = aliyun_s3.Bucket(
+                aliyun_s3.Auth(app.config.get('ALIYUN_OSS_ACCESS_KEY'), app.config.get('ALIYUN_OSS_SECRET_KEY')),
+                app.config.get('ALIYUN_OSS_ENDPOINT'),
+                self.bucket_name,
+                connect_timeout=30
+            )
         else:
             self.folder = app.config.get('STORAGE_LOCAL_PATH')
             if not os.path.isabs(self.folder):
@@ -54,6 +62,8 @@ class Storage:
         elif self.storage_type == 'azure-blob':
             blob_container = self.client.get_container_client(container=self.bucket_name)
             blob_container.upload_blob(filename, data)
+        elif self.storage_type == 'aliyun-oss':
+            self.client.put_object(filename, data)
         else:
             if not self.folder or self.folder.endswith('/'):
                 filename = self.folder + filename
@@ -86,6 +96,9 @@ class Storage:
             blob = self.client.get_container_client(container=self.bucket_name)
             blob = blob.get_blob_client(blob=filename)
             data = blob.download_blob().readall()
+        elif self.storage_type == 'aliyun-oss':
+            with closing(self.client.get_object(filename)) as obj:
+                data = obj.read()
         else:
             if not self.folder or self.folder.endswith('/'):
                 filename = self.folder + filename
@@ -118,6 +131,10 @@ class Storage:
             with closing(blob.download_blob()) as blob_stream:
                 while chunk := blob_stream.readall(4096):
                     yield chunk
+        elif self.storage_type == 'aliyun-oss':
+            with closing(self.client.get_object(filename)) as obj:
+                while chunk := obj.read(4096):
+                    yield chunk
         else:
             if not self.folder or self.folder.endswith('/'):
                 filename = self.folder + filename
@@ -142,6 +159,8 @@ class Storage:
             with open(target_filepath, "wb") as my_blob:
                 blob_data = blob.download_blob()
                 blob_data.readinto(my_blob)
+        elif self.storage_type == 'aliyun-oss':
+            self.client.get_object_to_file(filename, target_filepath)
         else:
             if not self.folder or self.folder.endswith('/'):
                 filename = self.folder + filename
@@ -164,6 +183,8 @@ class Storage:
         elif self.storage_type == 'azure-blob':
             blob = self.client.get_blob_client(container=self.bucket_name, blob=filename)
             return blob.exists()
+        elif self.storage_type == 'aliyun-oss':
+            return self.client.object_exists(filename)
         else:
             if not self.folder or self.folder.endswith('/'):
                 filename = self.folder + filename
@@ -178,6 +199,8 @@ class Storage:
         elif self.storage_type == 'azure-blob':
             blob_container = self.client.get_container_client(container=self.bucket_name)
             blob_container.delete_blob(filename)
+        elif self.storage_type == 'aliyun-oss':
+            self.client.delete_object(filename)
         else:
             if not self.folder or self.folder.endswith('/'):
                 filename = self.folder + filename
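The new 'aliyun-oss' branches map one-to-one onto the oss2 SDK's Bucket API. A minimal sketch of those calls in isolation (credentials, endpoint and bucket name are placeholders):

import oss2

auth = oss2.Auth('ACCESS_KEY', 'SECRET_KEY')  # placeholder credentials
bucket = oss2.Bucket(auth, 'https://oss-cn-hangzhou.aliyuncs.com',  # placeholder endpoint
                     'my-bucket', connect_timeout=30)

bucket.put_object('greeting.txt', b'hello')               # save
data = bucket.get_object('greeting.txt').read()           # load
print(bucket.object_exists('greeting.txt'))               # exists -> True
bucket.get_object_to_file('greeting.txt', '/tmp/g.txt')   # download
bucket.delete_object('greeting.txt')                      # delete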
@@ -62,6 +62,12 @@ model_config_partial_fields = {
     'pre_prompt': fields.String,
 }

+tag_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'type': fields.String
+}
+
 app_partial_fields = {
     'id': fields.String,
     'name': fields.String,
@@ -70,9 +76,11 @@ app_partial_fields = {
     'icon': fields.String,
     'icon_background': fields.String,
     'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config', allow_null=True),
-    'created_at': TimestampField
+    'created_at': TimestampField,
+    'tags': fields.List(fields.Nested(tag_fields))
 }


 app_pagination_fields = {
     'page': fields.Integer,
     'limit': fields.Integer(attribute='per_page'),
@@ -27,6 +27,11 @@ dataset_retrieval_model_fields = {
     'score_threshold': fields.Float
 }

+tag_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'type': fields.String
+}
+
 dataset_detail_fields = {
     'id': fields.String,
@@ -46,7 +51,8 @@ dataset_detail_fields = {
     'embedding_model': fields.String,
     'embedding_model_provider': fields.String,
     'embedding_available': fields.Boolean,
-    'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields)
+    'retrieval_model_dict': fields.Nested(dataset_retrieval_model_fields),
+    'tags': fields.List(fields.Nested(tag_fields))
 }

 dataset_query_detail_fields = {
8  api/fields/tag_fields.py  Normal file
@@ -0,0 +1,8 @@
+from flask_restful import fields
+
+tag_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'type': fields.String,
+    'binding_count': fields.String
+}
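These field maps are flask_restful marshalling templates; the new tags entries serialize each bound Tag through tag_fields. How a marshalled row picks them up, sketched with plain dicts (names are illustrative only):

from flask_restful import fields, marshal

tag_fields = {'id': fields.String, 'name': fields.String, 'type': fields.String}
app_partial_fields = {'name': fields.String,
                      'tags': fields.List(fields.Nested(tag_fields))}

app = {'name': 'demo', 'tags': [{'id': 't1', 'name': 'chatbots', 'type': 'app'}]}
print(marshal(app, app_partial_fields))
# -> an OrderedDict with 'name' and the nested list of tag dicts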
@@ -0,0 +1,62 @@
+"""add-tags-and-binding-table
+
+Revision ID: 3c7cac9521c6
+Revises: c3311b089690
+Create Date: 2024-04-11 06:17:34.278594
+
+"""
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '3c7cac9521c6'
+down_revision = 'c3311b089690'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('tag_bindings',
+    sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', postgresql.UUID(), nullable=True),
+    sa.Column('tag_id', postgresql.UUID(), nullable=True),
+    sa.Column('target_id', postgresql.UUID(), nullable=True),
+    sa.Column('created_by', postgresql.UUID(), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='tag_binding_pkey')
+    )
+    with op.batch_alter_table('tag_bindings', schema=None) as batch_op:
+        batch_op.create_index('tag_bind_tag_id_idx', ['tag_id'], unique=False)
+        batch_op.create_index('tag_bind_target_id_idx', ['target_id'], unique=False)
+
+    op.create_table('tags',
+    sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', postgresql.UUID(), nullable=True),
+    sa.Column('type', sa.String(length=16), nullable=False),
+    sa.Column('name', sa.String(length=255), nullable=False),
+    sa.Column('created_by', postgresql.UUID(), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='tag_pkey')
+    )
+    with op.batch_alter_table('tags', schema=None) as batch_op:
+        batch_op.create_index('tag_name_idx', ['name'], unique=False)
+        batch_op.create_index('tag_type_idx', ['type'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tags', schema=None) as batch_op:
+        batch_op.drop_index('tag_type_idx')
+        batch_op.drop_index('tag_name_idx')
+
+    op.drop_table('tags')
+    with op.batch_alter_table('tag_bindings', schema=None) as batch_op:
+        batch_op.drop_index('tag_bind_target_id_idx')
+        batch_op.drop_index('tag_bind_tag_id_idx')
+
+    op.drop_table('tag_bindings')
+    # ### end Alembic commands ###
@@ -100,10 +100,11 @@ class Account(UserMixin, db.Model):
         return db.session.query(ai).filter(
             ai.account_id == self.id
         ).all()

     # check current_user.current_tenant.current_role in ['admin', 'owner']
     @property
     def is_admin_or_owner(self):
-        return self._current_tenant.current_role in ['admin', 'owner']
+        return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)


 class TenantStatus(str, enum.Enum):
@@ -111,6 +112,16 @@ class TenantStatus(str, enum.Enum):
     ARCHIVE = 'archive'


+class TenantAccountRole(str, enum.Enum):
+    OWNER = 'owner'
+    ADMIN = 'admin'
+    NORMAL = 'normal'
+
+    @staticmethod
+    def is_privileged_role(role: str) -> bool:
+        return role and role in {TenantAccountRole.ADMIN, TenantAccountRole.OWNER}
+
+
 class Tenant(db.Model):
     __tablename__ = 'tenants'
     __table_args__ = (
@@ -132,11 +143,11 @@ class Tenant(db.Model):
             Account.id == TenantAccountJoin.account_id,
             TenantAccountJoin.tenant_id == self.id
         ).all()

     @property
     def custom_config_dict(self) -> dict:
         return json.loads(self.custom_config) if self.custom_config else {}

     @custom_config_dict.setter
     def custom_config_dict(self, value: dict):
         self.custom_config = json.dumps(value)
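is_privileged_role leans on the str mixin: because TenantAccountRole subclasses str, members compare (and, with the str mixin, hash) like their plain string values, so role strings loaded straight from the database can be tested directly. Condensed from the hunk above:

import enum

class TenantAccountRole(str, enum.Enum):
    OWNER = 'owner'
    ADMIN = 'admin'
    NORMAL = 'normal'

print(TenantAccountRole.ADMIN == 'admin')                              # True
print('owner' in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN})   # True
print('normal' in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN})  # False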
@@ -9,7 +9,7 @@ from sqlalchemy.dialects.postgresql import JSONB, UUID
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.account import Account
-from models.model import App, UploadFile
+from models.model import App, Tag, TagBinding, UploadFile


 class Dataset(db.Model):
@@ -118,6 +118,20 @@ class Dataset(db.Model):
         }
         return self.retrieval_model if self.retrieval_model else default_retrieval_model

+    @property
+    def tags(self):
+        tags = db.session.query(Tag).join(
+            TagBinding,
+            Tag.id == TagBinding.tag_id
+        ).filter(
+            TagBinding.target_id == self.id,
+            TagBinding.tenant_id == self.tenant_id,
+            Tag.tenant_id == self.tenant_id,
+            Tag.type == 'knowledge'
+        ).all()
+
+        return tags if tags else []
+
     @staticmethod
     def gen_collection_name_by_id(dataset_id: str) -> str:
         normalized_dataset_id = dataset_id.replace("-", "_")
@@ -148,7 +148,7 @@ class App(db.Model):
             return []
         agent_mode = app_model_config.agent_mode_dict
         tools = agent_mode.get('tools', [])

         provider_ids = []

         for tool in tools:
@@ -185,6 +185,20 @@ class App(db.Model):

         return deleted_tools

+    @property
+    def tags(self):
+        tags = db.session.query(Tag).join(
+            TagBinding,
+            Tag.id == TagBinding.tag_id
+        ).filter(
+            TagBinding.target_id == self.id,
+            TagBinding.tenant_id == self.tenant_id,
+            Tag.tenant_id == self.tenant_id,
+            Tag.type == 'app'
+        ).all()
+
+        return tags if tags else []
+

 class AppModelConfig(db.Model):
     __tablename__ = 'app_model_configs'
@@ -292,7 +306,8 @@ class AppModelConfig(db.Model):

     @property
     def agent_mode_dict(self) -> dict:
-        return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
+        return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [],
+                                                                    "prompt": None}

     @property
     def chat_prompt_config_dict(self) -> dict:
@@ -463,6 +478,7 @@ class InstalledApp(db.Model):
         return tenant


 class Conversation(db.Model):
     __tablename__ = 'conversations'
     __table_args__ = (
@@ -1175,11 +1191,11 @@ class MessageAgentThought(db.Model):
             return json.loads(self.message_files)
         else:
             return []

     @property
     def tools(self) -> list[str]:
         return self.tool.split(";") if self.tool else []

     @property
     def tool_labels(self) -> dict:
         try:
@@ -1189,7 +1205,7 @@ class MessageAgentThought(db.Model):
             return {}
         except Exception as e:
             return {}

     @property
     def tool_meta(self) -> dict:
         try:
@@ -1199,7 +1215,7 @@ class MessageAgentThought(db.Model):
             return {}
         except Exception as e:
             return {}

     @property
     def tool_inputs_dict(self) -> dict:
         tools = self.tools
@@ -1222,7 +1238,7 @@ class MessageAgentThought(db.Model):
             }
         except Exception as e:
             return {}

     @property
     def tool_outputs_dict(self) -> dict:
         tools = self.tools
@@ -1249,6 +1265,7 @@ class MessageAgentThought(db.Model):
                 tool: self.observation for tool in tools
             }


 class DatasetRetrieverResource(db.Model):
     __tablename__ = 'dataset_retriever_resources'
     __table_args__ = (
@@ -1274,3 +1291,37 @@ class DatasetRetrieverResource(db.Model):
     retriever_from = db.Column(db.Text, nullable=False)
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+
+
+class Tag(db.Model):
+    __tablename__ = 'tags'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='tag_pkey'),
+        db.Index('tag_type_idx', 'type'),
+        db.Index('tag_name_idx', 'name'),
+    )
+
+    TAG_TYPE_LIST = ['knowledge', 'app']
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(UUID, nullable=True)
+    type = db.Column(db.String(16), nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+
+class TagBinding(db.Model):
+    __tablename__ = 'tag_bindings'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='tag_binding_pkey'),
+        db.Index('tag_bind_target_id_idx', 'target_id'),
+        db.Index('tag_bind_tag_id_idx', 'tag_id'),
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(UUID, nullable=True)
+    tag_id = db.Column(UUID, nullable=True)
+    target_id = db.Column(UUID, nullable=True)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
@@ -55,7 +55,7 @@ xinference-client==0.9.4
 safetensors~=0.4.3
 zhipuai==1.0.7
 werkzeug~=3.0.1
-pymilvus~=2.3.0
+pymilvus==2.3.1
 qdrant-client==1.7.3
 cohere~=5.2.4
 pyyaml~=6.0.1
@@ -67,7 +67,7 @@ httpx[socks]~=0.24.1
 matplotlib~=3.8.2
 yfinance~=0.2.35
 pydub~=0.25.1
-gmpy2~=2.2.0a1
+gmpy2~=2.1.5
 numexpr~=2.9.0
 duckduckgo-search==5.2.2
 arxiv==2.1.0
@@ -80,3 +80,4 @@ lxml==5.1.0
 xlrd~=2.0.1
 pydantic~=1.10.0
 pgvecto-rs==0.1.4
+oss2==2.15.0
@@ -344,9 +344,9 @@ class TenantService:
     def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None:
         """Check member permission"""
         perms = {
-            'add': ['owner', 'admin'],
-            'remove': ['owner'],
-            'update': ['owner']
+            'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
+            'remove': [TenantAccountRole.OWNER],
+            'update': [TenantAccountRole.OWNER]
         }
         if action not in ['add', 'remove', 'update']:
             raise InvalidActionError("Invalid action.")
@@ -82,19 +82,22 @@ class AgentService:
                 tool_output = tool_outputs.get(tool_name, {})
                 tool_meta_data = tool_meta.get(tool_name, {})
                 tool_config = tool_meta_data.get('tool_config', {})
-                tool_icon = ToolManager.get_tool_icon(
-                    tenant_id=app_model.tenant_id,
-                    provider_type=tool_config.get('tool_provider_type', ''),
-                    provider_id=tool_config.get('tool_provider', ''),
-                )
-                if not tool_icon:
-                    tool_entity = find_agent_tool(tool_name)
-                    if tool_entity:
-                        tool_icon = ToolManager.get_tool_icon(
-                            tenant_id=app_model.tenant_id,
-                            provider_type=tool_entity.provider_type,
-                            provider_id=tool_entity.provider_id,
-                        )
+                if tool_config.get('tool_provider_type', '') != 'dataset-retrieval':
+                    tool_icon = ToolManager.get_tool_icon(
+                        tenant_id=app_model.tenant_id,
+                        provider_type=tool_config.get('tool_provider_type', ''),
+                        provider_id=tool_config.get('tool_provider', ''),
+                    )
+                    if not tool_icon:
+                        tool_entity = find_agent_tool(tool_name)
+                        if tool_entity:
+                            tool_icon = ToolManager.get_tool_icon(
+                                tenant_id=app_model.tenant_id,
+                                provider_type=tool_entity.provider_type,
+                                provider_id=tool_entity.provider_id,
+                            )
+                else:
+                    tool_icon = ''

                 tool_calls.append({
                     'status': 'success' if not tool_meta_data.get('error') else 'error',
@@ -5,23 +5,28 @@ from typing import cast

 import yaml
 from flask import current_app
+from flask_login import current_user
 from flask_sqlalchemy.pagination import Pagination

 from constants.model_template import default_app_templates
+from core.agent.entities import AgentToolEntity
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.configuration import ToolParameterConfigurationManager
 from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, AppMode, AppModelConfig
 from models.tools import ApiToolProvider
+from services.tag_service import TagService
 from services.workflow_service import WorkflowService


 class AppService:
-    def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination:
+    def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None:
         """
         Get app list with pagination
         :param tenant_id: tenant id
@@ -45,6 +50,14 @@ class AppService:
         if 'name' in args and args['name']:
             name = args['name'][:30]
             filters.append(App.name.ilike(f'%{name}%'))
+        if 'tag_ids' in args and args['tag_ids']:
+            target_ids = TagService.get_target_ids_by_tag_ids('app',
+                                                              tenant_id,
+                                                              args['tag_ids'])
+            if target_ids:
+                filters.append(App.id.in_(target_ids))
+            else:
+                return None

         app_models = db.paginate(
             db.select(App).where(*filters).order_by(App.created_at.desc()),
@@ -240,6 +253,64 @@ class AppService:

         return yaml.dump(export_data)

+    def get_app(self, app: App) -> App:
+        """
+        Get App
+        """
+        # get original app model config
+        if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
+            model_config: AppModelConfig = app.app_model_config
+            agent_mode = model_config.agent_mode_dict
+            # decrypt agent tool parameters if it's secret-input
+            for tool in agent_mode.get('tools') or []:
+                if not isinstance(tool, dict) or len(tool.keys()) <= 3:
+                    continue
+                agent_tool_entity = AgentToolEntity(**tool)
+                # get tool
+                try:
+                    tool_runtime = ToolManager.get_agent_tool_runtime(
+                        tenant_id=current_user.current_tenant_id,
+                        app_id=app.id,
+                        agent_tool=agent_tool_entity,
+                    )
+                    manager = ToolParameterConfigurationManager(
+                        tenant_id=current_user.current_tenant_id,
+                        tool_runtime=tool_runtime,
+                        provider_name=agent_tool_entity.provider_id,
+                        provider_type=agent_tool_entity.provider_type,
+                        identity_id=f'AGENT.{app.id}'
+                    )
+
+                    # get decrypted parameters
+                    if agent_tool_entity.tool_parameters:
+                        parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
+                        masked_parameter = manager.mask_tool_parameters(parameters or {})
+                    else:
+                        masked_parameter = {}
+
+                    # override tool parameters
+                    tool['tool_parameters'] = masked_parameter
+                except Exception as e:
+                    pass
+
+            # override agent mode
+            model_config.agent_mode = json.dumps(agent_mode)
+
+            class ModifiedApp(App):
+                """
+                Modified App class
+                """
+                def __init__(self, app):
+                    self.__dict__.update(app.__dict__)
+
+                @property
+                def app_model_config(self):
+                    return model_config
+
+            app = ModifiedApp(app)
+
+        return app
+
     def update_app(self, app: App, args: dict) -> App:
         """
         Update app
@@ -3,7 +3,7 @@ import os
 import requests

 from extensions.ext_database import db
-from models.account import TenantAccountJoin
+from models.account import TenantAccountJoin, TenantAccountRole


 class BillingService:
@@ -74,5 +74,5 @@ class BillingService:
             TenantAccountJoin.account_id == current_user.id
         ).first()

-        if join.role not in ['owner', 'admin']:
+        if not TenantAccountRole.is_privileged_role(join.role):
             raise ValueError('Only team owner or team admin can perform this action')
@@ -38,28 +38,39 @@ from services.errors.dataset import DatasetNameDuplicateError
|
||||
from services.errors.document import DocumentIndexingError
|
||||
from services.errors.file import FileNotExistsError
|
||||
from services.feature_service import FeatureModel, FeatureService
|
||||
from services.tag_service import TagService
|
||||
from services.vector_service import VectorService
|
||||
from tasks.clean_notion_document_task import clean_notion_document_task
|
||||
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
|
||||
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
|
||||
from tasks.recover_document_indexing_task import recover_document_indexing_task
|
||||
from tasks.retry_document_indexing_task import retry_document_indexing_task
|
||||
|
||||
|
||||
class DatasetService:
|
||||
|
||||
@staticmethod
|
||||
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None):
|
||||
def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None):
|
||||
if user:
|
||||
permission_filter = db.or_(Dataset.created_by == user.id,
|
||||
Dataset.permission == 'all_team_members')
|
||||
else:
|
||||
permission_filter = Dataset.permission == 'all_team_members'
|
||||
datasets = Dataset.query.filter(
|
||||
query = Dataset.query.filter(
|
||||
db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \
|
||||
.order_by(Dataset.created_at.desc()) \
|
||||
.paginate(
|
||||
.order_by(Dataset.created_at.desc())
|
||||
if search:
|
||||
query = query.filter(db.and_(Dataset.name.ilike(f'%{search}%')))
|
||||
if tag_ids:
|
||||
target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids)
|
||||
if target_ids:
|
||||
query = query.filter(db.and_(Dataset.id.in_(target_ids)))
|
||||
else:
|
||||
return [], 0
|
||||
datasets = query.paginate(
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
max_per_page=100,
|
||||
@@ -165,9 +176,36 @@ class DatasetService:
|
||||
# get embedding model setting
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_default_model_instance(
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING
|
||||
provider=data['embedding_model_provider'],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=data['embedding_model']
|
||||
)
|
||||
filtered_data['embedding_model'] = embedding_model.model
|
||||
filtered_data['embedding_model_provider'] = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider,
|
||||
embedding_model.model
|
||||
)
|
||||
filtered_data['collection_binding_id'] = dataset_collection_binding.id
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
else:
|
||||
if data['embedding_model_provider'] != dataset.embedding_model_provider or \
|
||||
data['embedding_model'] != dataset.embedding_model:
|
||||
action = 'update'
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=data['embedding_model_provider'],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=data['embedding_model']
|
||||
)
|
||||
filtered_data['embedding_model'] = embedding_model.model
|
||||
filtered_data['embedding_model_provider'] = embedding_model.provider
|
||||
@@ -376,6 +414,15 @@ class DocumentService:
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
|
||||
documents = db.session.query(Document).filter(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.indexing_status == 'error' or Document.indexing_status == 'paused'
|
||||
).all()
|
||||
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
|
||||
documents = db.session.query(Document).filter(
|
||||
@@ -440,6 +487,20 @@ class DocumentService:
         # trigger async task
         recover_document_indexing_task.delay(document.dataset_id, document.id)

+    @staticmethod
+    def retry_document(dataset_id: str, documents: list[Document]):
+        for document in documents:
+            # retry document indexing
+            document.indexing_status = 'waiting'
+            db.session.add(document)
+            db.session.commit()
+            # add retry flag
+            retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
+            redis_client.setex(retry_indexing_cache_key, 600, 1)
+        # trigger async task
+        document_ids = [document.id for document in documents]
+        retry_document_indexing_task.delay(dataset_id, document_ids)
+
     @staticmethod
     def get_documents_position(dataset_id):
         document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
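Read together with get_error_documents_by_dataset_id above, the retry flow could be wired up roughly as follows; this caller is a sketch, not code from this diff, and dataset_id is a placeholder:

    # Sketch only: re-queue every failed or paused document in a dataset.
    error_documents = DocumentService.get_error_documents_by_dataset_id(dataset_id)
    if error_documents:
        DocumentService.retry_document(dataset_id, error_documents)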
@@ -537,6 +598,7 @@ class DocumentService:
             db.session.commit()
         position = DocumentService.get_documents_position(dataset.id)
         document_ids = []
+        duplicate_document_ids = []
         if document_data["data_source"]["type"] == "upload_file":
             upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
             for file_id in upload_file_list:
@@ -553,6 +615,28 @@ class DocumentService:
                 data_source_info = {
                     "upload_file_id": file_id,
                 }
+                # check duplicate
+                if document_data.get('duplicate', False):
+                    document = Document.query.filter_by(
+                        dataset_id=dataset.id,
+                        tenant_id=current_user.current_tenant_id,
+                        data_source_type='upload_file',
+                        enabled=True,
+                        name=file_name
+                    ).first()
+                    if document:
+                        document.dataset_process_rule_id = dataset_process_rule.id
+                        document.updated_at = datetime.datetime.utcnow()
+                        document.created_from = created_from
+                        document.doc_form = document_data['doc_form']
+                        document.doc_language = document_data['doc_language']
+                        document.data_source_info = json.dumps(data_source_info)
+                        document.batch = batch
+                        document.indexing_status = 'waiting'
+                        db.session.add(document)
+                        documents.append(document)
+                        duplicate_document_ids.append(document.id)
+                        continue
                 document = DocumentService.build_document(dataset, dataset_process_rule.id,
                                                           document_data["data_source"]["type"],
                                                           document_data["doc_form"],
@@ -618,7 +702,10 @@ class DocumentService:
         db.session.commit()

         # trigger async task
-        document_indexing_task.delay(dataset.id, document_ids)
+        if document_ids:
+            document_indexing_task.delay(dataset.id, document_ids)
+        if duplicate_document_ids:
+            duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

         return documents, batch
@@ -626,7 +713,8 @@ class DocumentService:
     def check_documents_upload_quota(count: int, features: FeatureModel):
         can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
         if count > can_upload_size:
-            raise ValueError(f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')
+            raise ValueError(
+                f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.')

     @staticmethod
     def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
@@ -752,7 +840,6 @@ class DocumentService:
         db.session.commit()
         # trigger async task
         document_indexing_update_task.delay(document.dataset_id, document.id)
-
         return document

     @staticmethod
161 api/services/tag_service.py Normal file
@@ -0,0 +1,161 @@
import uuid

from flask_login import current_user
from sqlalchemy import func
from werkzeug.exceptions import NotFound

from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Tag, TagBinding


class TagService:
    @staticmethod
    def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
        query = db.session.query(
            Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
        ).outerjoin(
            TagBinding, Tag.id == TagBinding.tag_id
        ).filter(
            Tag.type == tag_type,
            Tag.tenant_id == current_tenant_id
        )
        if keyword:
            query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
        query = query.group_by(
            Tag.id
        )
        results = query.order_by(Tag.created_at.desc()).all()
        return results

    @staticmethod
    def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
        tags = db.session.query(Tag).filter(
            Tag.id.in_(tag_ids),
            Tag.tenant_id == current_tenant_id,
            Tag.type == tag_type
        ).all()
        if not tags:
            return []
        tag_ids = [tag.id for tag in tags]
        tag_bindings = db.session.query(
            TagBinding.target_id
        ).filter(
            TagBinding.tag_id.in_(tag_ids),
            TagBinding.tenant_id == current_tenant_id
        ).all()
        if not tag_bindings:
            return []
        results = [tag_binding.target_id for tag_binding in tag_bindings]
        return results

    @staticmethod
    def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
        tags = db.session.query(Tag).join(
            TagBinding,
            Tag.id == TagBinding.tag_id
        ).filter(
            TagBinding.target_id == target_id,
            TagBinding.tenant_id == current_tenant_id,
            Tag.tenant_id == current_tenant_id,
            Tag.type == tag_type
        ).all()

        return tags if tags else []

    @staticmethod
    def save_tags(args: dict) -> Tag:
        tag = Tag(
            id=str(uuid.uuid4()),
            name=args['name'],
            type=args['type'],
            created_by=current_user.id,
            tenant_id=current_user.current_tenant_id
        )
        db.session.add(tag)
        db.session.commit()
        return tag

    @staticmethod
    def update_tags(args: dict, tag_id: str) -> Tag:
        tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
        if not tag:
            raise NotFound("Tag not found")
        tag.name = args['name']
        db.session.commit()
        return tag

    @staticmethod
    def get_tag_binding_count(tag_id: str) -> int:
        count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
        return count

    @staticmethod
    def delete_tag(tag_id: str):
        tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
        if not tag:
            raise NotFound("Tag not found")
        db.session.delete(tag)
        # delete tag bindings
        tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
        if tag_bindings:
            for tag_binding in tag_bindings:
                db.session.delete(tag_binding)
        db.session.commit()

    @staticmethod
    def save_tag_binding(args):
        # check if target exists
        TagService.check_target_exists(args['type'], args['target_id'])
        # save tag bindings
        for tag_id in args['tag_ids']:
            tag_binding = db.session.query(TagBinding).filter(
                TagBinding.tag_id == tag_id,
                TagBinding.target_id == args['target_id']
            ).first()
            if tag_binding:
                continue
            new_tag_binding = TagBinding(
                tag_id=tag_id,
                target_id=args['target_id'],
                tenant_id=current_user.current_tenant_id,
                created_by=current_user.id
            )
            db.session.add(new_tag_binding)
        db.session.commit()

    @staticmethod
    def delete_tag_binding(args):
        # check if target exists
        TagService.check_target_exists(args['type'], args['target_id'])
        # delete tag binding
        tag_binding = db.session.query(TagBinding).filter(
            TagBinding.target_id == args['target_id'],
            TagBinding.tag_id == args['tag_id']
        ).first()
        if tag_binding:
            db.session.delete(tag_binding)
            db.session.commit()

    @staticmethod
    def check_target_exists(type: str, target_id: str):
        if type == 'knowledge':
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == current_user.current_tenant_id,
                Dataset.id == target_id
            ).first()
            if not dataset:
                raise NotFound("Dataset not found")
        elif type == 'app':
            app = db.session.query(App).filter(
                App.tenant_id == current_user.current_tenant_id,
                App.id == target_id
            ).first()
            if not app:
                raise NotFound("App not found")
        else:
            raise NotFound("Invalid binding type")
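A rough usage sketch for the new service, assuming a request context where flask_login's current_user is bound; the tag name and dataset_id below are placeholders, not values from this diff:

    # Create a knowledge tag, bind it to a dataset, then resolve targets by tag.
    tag = TagService.save_tags({'name': 'finance', 'type': 'knowledge'})
    TagService.save_tag_binding({
        'type': 'knowledge',
        'target_id': dataset_id,   # placeholder dataset UUID
        'tag_ids': [tag.id],
    })
    dataset_ids = TagService.get_target_ids_by_tag_ids(
        'knowledge', current_user.current_tenant_id, [tag.id])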
@@ -16,6 +16,7 @@ from models.dataset import (
 )


+# Add import statement for ValueError
 @shared_task(queue='dataset')
 def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
                        index_struct: str, collection_binding_id: str, doc_form: str):
@@ -48,6 +49,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
             logging.info(click.style('No documents found for dataset: {}'.format(dataset_id), fg='green'))
         else:
             logging.info(click.style('Cleaning documents for dataset: {}'.format(dataset_id), fg='green'))
+            # Specify the index type before initializing the index processor
+            if doc_form is None:
+                raise ValueError("Index type must be specified.")
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
             index_processor.clean(dataset, None)
@@ -64,6 +64,39 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):

             # save vector index
             index_processor.load(dataset, documents, with_keywords=False)
+        elif action == 'update':
+            # clean the existing index
+            index_processor.clean(dataset, None, with_keywords=False)
+            dataset_documents = db.session.query(DatasetDocument).filter(
+                DatasetDocument.dataset_id == dataset_id,
+                DatasetDocument.indexing_status == 'completed',
+                DatasetDocument.enabled == True,
+                DatasetDocument.archived == False,
+            ).all()
+            # add new index
+            if dataset_documents:
+                documents = []
+                for dataset_document in dataset_documents:
+                    # collect enabled segments to rebuild the vector index
+                    segments = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.document_id == dataset_document.id,
+                        DocumentSegment.enabled == True
+                    ).order_by(DocumentSegment.position.asc()).all()
+                    for segment in segments:
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
+                        )
+                        documents.append(document)
+
+                # save vector index
+                index_processor.load(dataset, documents, with_keywords=False)

     end_at = time.perf_counter()
     logging.info(
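The new 'update' branch is dispatched the same way as the existing actions; a one-line sketch of enqueueing it through Celery (dataset_id is a placeholder):

    # Rebuild a dataset's vector index, e.g. after its embedding model changes.
    deal_dataset_vector_index_task.delay(dataset_id, 'update')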
94 api/tasks/duplicate_document_indexing_task.py Normal file
@@ -0,0 +1,94 @@
import datetime
import logging
import time

import click
from celery import shared_task
from flask import current_app

from core.indexing_runner import DocumentIsPausedException, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService


@shared_task(queue='dataset')
def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
    """
    Async process documents
    :param dataset_id:
    :param document_ids:

    Usage: duplicate_document_indexing_task.delay(dataset_id, document_ids)
    """
    documents = []
    start_at = time.perf_counter()

    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

    # check document limit
    features = FeatureService.get_features(dataset.tenant_id)
    try:
        if features.billing.enabled:
            vector_space = features.vector_space
            count = len(document_ids)
            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
            if 0 < vector_space.limit <= vector_space.size:
                raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
                                 "your subscription.")
    except Exception as e:
        for document_id in document_ids:
            document = db.session.query(Document).filter(
                Document.id == document_id,
                Document.dataset_id == dataset_id
            ).first()
            if document:
                document.indexing_status = 'error'
                document.error = str(e)
                document.stopped_at = datetime.datetime.utcnow()
                db.session.add(document)
        db.session.commit()
        return

    for document_id in document_ids:
        logging.info(click.style('Start process document: {}'.format(document_id), fg='green'))

        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()

        if document:
            # clean old data
            index_type = document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()

            segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
            if segments:
                index_node_ids = [segment.index_node_id for segment in segments]

                # delete from vector index
                index_processor.clean(dataset, index_node_ids)

                for segment in segments:
                    db.session.delete(segment)
                db.session.commit()

            document.indexing_status = 'parsing'
            document.processing_started_at = datetime.datetime.utcnow()
            documents.append(document)
            db.session.add(document)
    db.session.commit()

    try:
        indexing_runner = IndexingRunner()
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
    except DocumentIsPausedException as ex:
        logging.info(click.style(str(ex), fg='yellow'))
    except Exception:
        pass
91 api/tasks/retry_document_indexing_task.py Normal file
@@ -0,0 +1,91 @@
import datetime
import logging
import time

import click
from celery import shared_task

from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService


@shared_task(queue='dataset')
def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
    """
    Async retry document indexing
    :param dataset_id:
    :param document_ids:

    Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
    """
    documents = []
    start_at = time.perf_counter()

    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
    for document_id in document_ids:
        retry_indexing_cache_key = 'document_{}_is_retried'.format(document_id)
        # check document limit
        features = FeatureService.get_features(dataset.tenant_id)
        try:
            if features.billing.enabled:
                vector_space = features.vector_space
                if 0 < vector_space.limit <= vector_space.size:
                    raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
                                     "your subscription.")
        except Exception as e:
            document = db.session.query(Document).filter(
                Document.id == document_id,
                Document.dataset_id == dataset_id
            ).first()
            if document:
                document.indexing_status = 'error'
                document.error = str(e)
                document.stopped_at = datetime.datetime.utcnow()
                db.session.add(document)
                db.session.commit()
            redis_client.delete(retry_indexing_cache_key)
            return

        logging.info(click.style('Start retry document: {}'.format(document_id), fg='green'))
        document = db.session.query(Document).filter(
            Document.id == document_id,
            Document.dataset_id == dataset_id
        ).first()
        try:
            if document:
                # clean old data
                index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

                segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
                if segments:
                    index_node_ids = [segment.index_node_id for segment in segments]
                    # delete from vector index
                    index_processor.clean(dataset, index_node_ids)

                    for segment in segments:
                        db.session.delete(segment)
                    db.session.commit()

                document.indexing_status = 'parsing'
                document.processing_started_at = datetime.datetime.utcnow()
                db.session.add(document)
                db.session.commit()

                indexing_runner = IndexingRunner()
                indexing_runner.run([document])
                redis_client.delete(retry_indexing_cache_key)
        except Exception as ex:
            document.indexing_status = 'error'
            document.error = str(ex)
            document.stopped_at = datetime.datetime.utcnow()
            db.session.add(document)
            db.session.commit()
            logging.info(click.style(str(ex), fg='yellow'))
            redis_client.delete(retry_indexing_cache_key)
    end_at = time.perf_counter()
    logging.info(click.style('Retry dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
0 api/tests/integration_tests/vdb/milvus/__init__.py Normal file
38 api/tests/integration_tests/vdb/milvus/test_milvus.py Normal file
@@ -0,0 +1,38 @@
import uuid

from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
    get_sample_document,
    get_sample_embedding,
    get_sample_query_vector,
    setup_mock_redis,
)


def test_milvus_vector(setup_mock_redis) -> None:
    dataset_id = str(uuid.uuid4())
    vector = MilvusVector(
        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
        config=MilvusConfig(
            host='localhost',
            port=19530,
            user='root',
            password='Milvus',
        )
    )

    # create vector
    vector.create(
        texts=[get_sample_document(dataset_id)],
        embeddings=[get_sample_embedding()],
    )

    # search by vector
    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
    assert len(hits_by_vector) >= 1

    # Milvus does not support full-text search in versions < 2.3.x

    # delete vector
    vector.delete()
0 api/tests/integration_tests/vdb/qdrant/__init__.py Normal file
40 api/tests/integration_tests/vdb/qdrant/test_qdrant.py Normal file
@@ -0,0 +1,40 @@
import uuid

from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
    get_sample_document,
    get_sample_embedding,
    get_sample_query_vector,
    get_sample_text,
    setup_mock_redis,
)


def test_qdrant_vector(setup_mock_redis) -> None:
    dataset_id = str(uuid.uuid4())
    vector = QdrantVector(
        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
        group_id=dataset_id,
        config=QdrantConfig(
            endpoint='http://localhost:6333',
            api_key='difyai123456',
        )
    )

    # create vector
    vector.create(
        texts=[get_sample_document(dataset_id)],
        embeddings=[get_sample_embedding()],
    )

    # search by vector
    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
    assert len(hits_by_vector) >= 1

    # search by full text
    hits_by_full_text = vector.search_by_full_text(query=get_sample_text())
    assert len(hits_by_full_text) >= 1

    # delete vector
    vector.delete()
46 api/tests/integration_tests/vdb/test_vector_store.py Normal file
@@ -0,0 +1,46 @@
from unittest.mock import MagicMock

import pytest

from core.rag.models.document import Document
from extensions import ext_redis


def get_sample_text() -> str:
    return 'test_text'


def get_sample_embedding() -> list[float]:
    return [1.1, 2.2, 3.3]


def get_sample_query_vector() -> list[float]:
    return get_sample_embedding()


def get_sample_document(sample_dataset_id: str) -> Document:
    doc = Document(
        page_content=get_sample_text(),
        metadata={
            "doc_id": sample_dataset_id,
            "doc_hash": sample_dataset_id,
            "document_id": sample_dataset_id,
            "dataset_id": sample_dataset_id,
        }
    )
    return doc


@pytest.fixture
def setup_mock_redis() -> None:
    # mock get
    ext_redis.redis_client.get = MagicMock(return_value=None)

    # mock set
    ext_redis.redis_client.set = MagicMock(return_value=None)

    # mock lock
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = mock_redis_lock
41 api/tests/integration_tests/vdb/weaviate/test_weaviate.py Normal file
@@ -0,0 +1,41 @@
import uuid

from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
from models.dataset import Dataset
from tests.integration_tests.vdb.test_vector_store import (
    get_sample_document,
    get_sample_embedding,
    get_sample_query_vector,
    get_sample_text,
    setup_mock_redis,
)


def test_weaviate_vector(setup_mock_redis) -> None:
    attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
    dataset_id = str(uuid.uuid4())
    vector = WeaviateVector(
        collection_name=Dataset.gen_collection_name_by_id(dataset_id),
        config=WeaviateConfig(
            endpoint='http://localhost:8080',
            api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih',
        ),
        attributes=attributes
    )

    # create vector
    vector.create(
        texts=[get_sample_document(dataset_id)],
        embeddings=[get_sample_embedding()],
    )

    # search by vector
    hits_by_vector = vector.search_by_vector(query_vector=get_sample_query_vector())
    assert len(hits_by_vector) >= 1

    # search by full text
    hits_by_full_text = vector.search_by_full_text(query=get_sample_text())
    assert len(hits_by_full_text) >= 1

    # delete vector
    vector.delete()
@@ -65,7 +65,8 @@ def test_execute_llm(setup_openai_mock):
     pool = VariablePool(system_variables={
         SystemVariable.QUERY: 'what\'s the weather today?',
         SystemVariable.FILES: [],
-        SystemVariable.CONVERSATION: 'abababa'
+        SystemVariable.CONVERSATION_ID: 'abababa',
+        SystemVariable.USER_ID: 'aaa'
     }, user_inputs={})
     pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
@@ -238,8 +238,8 @@ def test__get_completion_model_prompt_messages():
     prompt_rules = prompt_template['prompt_rules']
     full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
         max_token_limit=2000,
-        ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-        human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+        human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+        ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
     )}
     real_prompt = prompt_template['prompt_template'].format(full_inputs)
@@ -18,7 +18,7 @@ def test_default_value():
         with pytest.raises(ValidationError) as e:
             MilvusConfig(**config)
         assert e.value.errors()[1]['msg'] == f'config MILVUS_{key.upper()} is required'

     config = MilvusConfig(**valid_config)
     assert config.secure is False
     assert config.database == 'default'
@@ -28,6 +28,7 @@ def test_execute_answer():
     # construct variable pool
     pool = VariablePool(system_variables={
         SystemVariable.FILES: [],
+        SystemVariable.USER_ID: 'aaa'
     }, user_inputs={})
     pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny')
     pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.')
@@ -118,6 +118,7 @@ def test_execute_if_else_result_true():
     # construct variable pool
     pool = VariablePool(system_variables={
         SystemVariable.FILES: [],
+        SystemVariable.USER_ID: 'aaa'
     }, user_inputs={})
     pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
     pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
@@ -179,6 +180,7 @@ def test_execute_if_else_result_false():
     # construct variable pool
     pool = VariablePool(system_variables={
         SystemVariable.FILES: [],
+        SystemVariable.USER_ID: 'aaa'
     }, user_inputs={})
     pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['1ab', 'def'])
     pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ab', 'def'])
29 api/tests/unit_tests/libs/test_rsa.py Normal file
@@ -0,0 +1,29 @@
import rsa as pyrsa
from Crypto.PublicKey import RSA

from libs import gmpy2_pkcs10aep_cipher


def test_gmpy2_pkcs10aep_cipher() -> None:
    rsa_key_pair = pyrsa.newkeys(2048)
    public_key = rsa_key_pair[0].save_pkcs1()
    private_key = rsa_key_pair[1].save_pkcs1()

    public_rsa_key = RSA.import_key(public_key)
    public_cipher_rsa2 = gmpy2_pkcs10aep_cipher.new(public_rsa_key)

    private_rsa_key = RSA.import_key(private_key)
    private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key)

    raw_text = 'raw_text'
    raw_text_bytes = raw_text.encode()

    # RSA encryption by public key and decryption by private key
    encrypted_by_pub_key = public_cipher_rsa2.encrypt(message=raw_text_bytes)
    decrypted_by_pub_key = private_cipher_rsa.decrypt(encrypted_by_pub_key)
    assert decrypted_by_pub_key == raw_text_bytes

    # RSA encryption and decryption by private key
    encrypted_by_private_key = private_cipher_rsa.encrypt(message=raw_text_bytes)
    decrypted_by_private_key = private_cipher_rsa.decrypt(encrypted_by_private_key)
    assert decrypted_by_private_key == raw_text_bytes
12 api/tests/unit_tests/models/test_account.py Normal file
@@ -0,0 +1,12 @@
from models.account import TenantAccountRole


def test_account_is_privileged_role() -> None:
    assert TenantAccountRole.ADMIN == 'admin'
    assert TenantAccountRole.OWNER == 'owner'
    assert TenantAccountRole.NORMAL == 'normal'

    assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN)
    assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER)
    assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL)
    assert not TenantAccountRole.is_privileged_role('')
@@ -89,7 +89,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
         )
     ]

-    nodes = workflow_converter._convert_to_http_request_node(
+    nodes, _ = workflow_converter._convert_to_http_request_node(
         app_model=app_model,
         variables=default_variables,
         external_data_variables=external_data_variables
@@ -159,7 +159,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
         )
     ]

-    nodes = workflow_converter._convert_to_http_request_node(
+    nodes, _ = workflow_converter._convert_to_http_request_node(
         app_model=app_model,
         variables=default_variables,
         external_data_variables=external_data_variables
@@ -9,3 +9,6 @@ dev/pytest/pytest_tools.sh

 # Workflow
 dev/pytest/pytest_workflow.sh
+
+# Unit tests
+dev/pytest/pytest_unit_tests.sh
5 dev/pytest/pytest_unit_tests.sh Executable file
@@ -0,0 +1,5 @@
#!/bin/bash
set -x

# unit tests
pytest api/tests/unit_tests
4 dev/pytest/pytest_vdb.sh Executable file
@@ -0,0 +1,4 @@
#!/bin/bash
set -x

pytest api/tests/integration_tests/vdb/
@@ -36,7 +36,7 @@ services:
       timeout: 20s
       retries: 3

-  standalone:
+  milvus-standalone:
     container_name: milvus-standalone
     image: milvusdb/milvus:v2.3.1
     command: ["milvus", "run", "standalone"]
12 docker/docker-compose.qdrant.yaml Normal file
@@ -0,0 +1,12 @@
version: '3'
services:
  # Qdrant vector store.
  qdrant:
    image: langgenius/qdrant:v1.7.3
    restart: always
    volumes:
      - ./volumes/qdrant:/qdrant/storage
    environment:
      QDRANT_API_KEY: 'difyai123456'
    ports:
      - "6333:6333"
@@ -2,7 +2,7 @@ version: '3'
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.6.4
+    image: langgenius/dify-api:0.6.5
     restart: always
     environment:
       # Startup mode, 'api' starts the API server.
@@ -150,7 +150,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.6.4
+    image: langgenius/dify-api:0.6.5
     restart: always
     environment:
       # Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -238,7 +238,7 @@ services:

   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.6.4
+    image: langgenius/dify-web:0.6.5
     restart: always
     environment:
       EDITION: SELF_HOSTED
Some files were not shown because too many files have changed in this diff.