Compare commits


102 Commits
0.5.4 ... 0.5.7

Author SHA1 Message Date
takatost
5bd3b02be6 version to 0.5.7 (#2610) 2024-02-28 18:07:13 +08:00
crazywoola
3cf5c1853d Fix: default button behavior (#2609) 2024-02-28 17:34:20 +08:00
takatost
a4d86496e1 fix: notion extractor raise 'NoneType' object has no attribute 'curre… (#2608) 2024-02-28 17:08:27 +08:00
takatost
90bdc85f8c fix: AppParameterApi.get() got an unexpected keyword argument 'end_user' (#2607) 2024-02-28 16:46:50 +08:00
takatost
0828873b52 fix: missing default user for APP service api (#2606) 2024-02-28 16:09:56 +08:00
crazywoola
816b707a16 Fix: explore apps is not shown (#2604) 2024-02-28 15:43:42 +08:00
crazywoola
c9257ab4bf Fix/2559 upload powered by brand image not showing up (#2602) 2024-02-28 15:17:49 +08:00
cola
69ce3b3d33 fix props.appDetail.api_base_url /v1 repeat error (#2601) 2024-02-28 15:13:38 +08:00
crazywoola
c4caa7c401 doc: props.appDetail.api_base_url (#2597) 2024-02-28 13:40:57 +08:00
Joshua
dc93a292c3 Feat/provider mistralai (#2598) 2024-02-28 13:39:55 +08:00
takatost
174ee1b646 fix: parameter user exceeded max length when invoking moonshot llm (#2596) 2024-02-28 12:23:34 +08:00
Joshua
9b1c4f47fb feat:add mistral ai (#2594) 2024-02-28 12:22:57 +08:00
crazywoola
582ba45c00 Fix 500 error when creating from the template and the provider is None (#2591) 2024-02-28 11:27:17 +08:00
Rozstone
f1cbd55007 enhancement: skip fetching to improve user experience when switching … (#2580) 2024-02-27 19:16:22 +08:00
Yeuoly
3a34370422 fix: convert tool messages into user messages in react mode and fill … (#2584) 2024-02-27 19:15:07 +08:00
Bowen Liang
29ab244de6 fix: correct the parent class of CacheEmbedding (#2578) 2024-02-27 18:05:48 +08:00
Jyong
920b2c2b40 Fix/hit test tsne issue (#2581)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-27 17:30:52 +08:00
Yeuoly
ac96d192a6 fix: parameter type handling in API tool and parser (#2574) 2024-02-27 15:59:11 +08:00
Rozstone
07fbeb6cf0 enhancement: improve client-side code (#2568) 2024-02-27 15:58:57 +08:00
Jyong
fc64cdee64 fix milvus delete by ids error (#2573)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-27 12:23:13 +08:00
zxhlyh
0c0e96c55f fix: notion binding (#2572) 2024-02-27 11:59:54 +08:00
Jyong
5b953c1ef2 Fix some RAG bugs (#2570)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-27 11:39:05 +08:00
Bowen Liang
562ca45e07 fix weaviate delete_by_ids (#2565) 2024-02-27 11:14:35 +08:00
crazywoola
6bbd53512e Add Dify Meetup Event on Mar 9 (#2566) 2024-02-27 10:40:26 +08:00
Bowen Liang
e352a8ed1b chore: remove redundant casting flask app config into dict (#2564) 2024-02-27 09:39:26 +08:00
Bowen Liang
e55225e2bc fix typo in error message of supported keyword store (#2560) 2024-02-27 00:47:36 +08:00
Yeuoly
3e63abd335 Feat/json mode (#2563) 2024-02-26 23:34:40 +08:00
Jyong
0620fa3094 Feat/vdb migrate command (#2562)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-26 19:47:29 +08:00
Rozstone
d93288f711 Feat/use searchparams as state (#2554)
Co-authored-by: crazywoola <427733928@qq.com>
2024-02-26 12:52:59 +08:00
Rozstone
ca69af7b97 feat: change max_question_num to 5 (#2520)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-24 09:28:27 +08:00
takatost
952e13fef8 Update README_CN.md (#2550) 2024-02-23 17:38:03 +08:00
Jyong
4be3087642 Fix/new RAG bugs (#2547)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-23 16:54:15 +08:00
Garfield Dai
49da8a23a8 feat: openai llm get trial or paid models from config. (#2546) 2024-02-23 16:48:58 +08:00
Garfield Dai
3ad943a9eb Feat/openai llm trial paid config (#2545) 2024-02-23 16:12:43 +08:00
zxhlyh
3082093293 fix: webapp name (#2543) 2024-02-23 14:54:03 +08:00
Jyong
b03bbab5ad fix dev/reformat (#2542)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-23 14:53:24 +08:00
crazywoola
9574730050 Feat/i18n restructure (#2529) 2024-02-23 14:31:06 +08:00
Jyong
91ea6fe4ee Fix/langchain document schema (#2539)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-23 14:16:44 +08:00
Joel
769be13189 chore: add api key and value placeholder (#2538) 2024-02-23 13:55:43 +08:00
Bowen Liang
e42175241e fix: tolerate exceptions in cleaning up index when vector db service unavailable (#2533) 2024-02-23 12:30:39 +08:00
Yeuoly
12257b438b Fix/tool default value (#2536) 2024-02-23 12:02:29 +08:00
Bowen Liang
9ecc736c30 fix: update current tenant id of account when switching tenant (#2530) 2024-02-23 10:51:19 +08:00
Jyong
6c4e6bf1d6 Feat/dify rag (#2528)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-22 23:31:57 +08:00
Jyong
97fe817186 Fix/upload limit (#2521)
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2024-02-22 17:16:22 +08:00
Charlie.Wei
52b12ed7eb Voice audition (#2504)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-22 16:06:17 +08:00
Yeuoly
d8ab4474b4 fix: bing search response filter (#2519) 2024-02-22 13:06:55 +08:00
crazywoola
1ecbd95adf Fix #2512 (#2515) 2024-02-22 09:22:57 +08:00
crazywoola
cad6e6624f fix: config not exists (#2513) 2024-02-21 19:27:38 +08:00
crazywoola
3505cbe05c update issue template (#2507) 2024-02-21 14:08:11 +08:00
Joel
e15359e589 fix: api doc example error (#2505) 2024-02-21 12:03:48 +08:00
Yeuoly
edb86f5f5a Feat/stream react (#2498) 2024-02-21 10:45:59 +08:00
Yash_1124
adf2651d1f FEAT: Add DuckDuckGo Search Tool for Enhanced Privacy-Focused Search Functionality (#2499) 2024-02-21 10:42:34 +08:00
Chenhe Gu
5031d64e28 Chore/delete chunk decode error alert (#2500) 2024-02-21 03:17:33 +08:00
Yeuoly
ae3ad59b16 Refactor agent history organization and initialization of agent scrat… (#2495) 2024-02-20 19:03:43 +08:00
Yeuoly
e6cd7b0467 feat: increase max tools (#2497) 2024-02-20 19:03:10 +08:00
crazywoola
97e9f52331 doc: typo in chat (#2492) 2024-02-20 16:08:01 +08:00
Yeuoly
25957d917a Add default values for optional parameters in API tool and parser (#2491) 2024-02-20 16:07:43 +08:00
Jyong
20b932da97 del doc support (#2494)
Co-authored-by: jyong <jyong@dify.ai>
2024-02-20 16:05:09 +08:00
zxhlyh
207080babc fix: audio to text (#2493) 2024-02-20 15:16:46 +08:00
Yeuoly
48bacd01cc fix: incorrect tool name (#2489) 2024-02-20 14:50:57 +08:00
zxhlyh
297d0f1f30 fix: code-based extension (#2490) 2024-02-20 14:49:00 +08:00
zxhlyh
eedbe1b770 fix: chat restart (#2488) 2024-02-20 11:24:27 +08:00
kukuze
5ff6b1da07 Windows local deployment: switching the "tool" interface failed (#2483) 2024-02-19 20:03:20 +08:00
takatost
8b49e0ee2a bump version to 0.5.6 (#2482) 2024-02-19 17:13:55 +08:00
crazywoola
e031ec9359 remove: parameters in seeds (#2481) 2024-02-19 17:00:46 +08:00
takatost
1bd1cd6938 fix: event handlers not registered globally (#2479) 2024-02-19 16:04:52 +08:00
Yash_1124
81c5a21b8d FEAT: add image styling in markdown (#2441)
Co-authored-by: crazywoola <427733928@qq.com>
2024-02-19 15:07:45 +08:00
Koen Farell
61e4bbabaf feat: added Ukrainian language support (#2473) 2024-02-19 13:11:23 +08:00
takatost
4cf475680d fix: credential verification of baichuan did not throw all errors (#2475) 2024-02-19 11:52:52 +08:00
Yeuoly
ca4aa340f6 fix: Add model_uid validation for model_uid in Xinference models (#2468) 2024-02-19 10:43:25 +08:00
Joel
767d8a4b05 fix: hybrid search may pass rerank enable false (#2467) 2024-02-18 17:52:05 +08:00
TseIan
0b8dcaba8f Chore: Add type files and unit test ci for Node.js SDK (#2268)
Co-authored-by: xieweicheng <xieweicheng@bytedance.com>
2024-02-18 15:54:14 +08:00
wjryours
af6a318aae fix: windows load provider file error (#2463) 2024-02-18 15:48:25 +08:00
Charlie.Wei
c6e2900be7 Display selected tts voice name (#2459)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-18 15:39:25 +08:00
crazywoola
963d9b6032 Feature/display selected info for tts (#2454) 2024-02-16 20:05:14 +08:00
johnpccd
b2ee738bb1 Ignore SSE comments to support openrouter streaming (#2432) 2024-02-16 10:00:10 +08:00
Charlie.Wei
c8ca3ff404 Tts add voice choose (#2453)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-16 01:10:11 +08:00
Charlie.Wei
5d8fa2c7af Tts add voice choose (#2452)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-16 00:15:22 +08:00
takatost
58df5e5376 fix: tts voice language to zh-Hans instead of zh-CN (#2450) 2024-02-16 00:05:29 +08:00
takatost
348ad1a624 Update pull_request_template.md (#2451) 2024-02-16 00:05:18 +08:00
takatost
73e17d5aa8 Create pull_request_template.md (#2449) 2024-02-15 23:35:59 +08:00
Charlie.Wei
300d9892a5 tts add voice choose (#2391)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-02-15 22:41:18 +08:00
Yeuoly
e47b5b43b8 fix: baichuan frequency_penalty (#2446) 2024-02-14 20:11:41 +08:00
takatost
21c9d9e200 feat: add introduction field in log detail response of chat app (#2445) 2024-02-14 12:38:13 +08:00
Igor Voloc
4f6916c4d8 Update SMTP environment variable name in docker-compose (#2444) 2024-02-14 12:29:27 +08:00
takatost
8633957726 version to 0.5.5 (#2440) 2024-02-13 12:31:49 +08:00
zxhlyh
0850c953b3 fix: variable in opener (#2437) 2024-02-12 22:22:57 +08:00
Yeuoly
23e95fd7ab Fix tool provider credential caching issue (#2433) 2024-02-12 18:17:43 +08:00
takatost
e1045f01c6 perf: optimize add hit count query performance when dataset hit (#2436) 2024-02-12 13:50:43 +08:00
takatost
e6d22fc3a0 fix: account has no owner workspace by member inviting (#2435) 2024-02-12 02:09:01 +08:00
Bowen Liang
9232244920 fix recreating users' default tenant relations when loading user (#2408) 2024-02-12 01:31:40 +08:00
takatost
476eb90a90 fix: List not found in account service (#2434) 2024-02-12 00:56:17 +08:00
Bowen Liang
063191889d chore: apply ruff's pyupgrade linter rules to modernize Python code with targeted version (#2419) 2024-02-09 15:21:33 +08:00
Bowen Liang
589099a005 fix: possible unsent function call in the last chunk of streaming response in OpenAI provider (#2422) 2024-02-09 14:43:38 +08:00
takatost
a0ec7de058 clean: remove no-use ecc_aes.py (#2426) 2024-02-08 20:47:54 +08:00
Bowen Liang
14a19a3da9 chore: apply ruff's pyflakes linter rules (#2420) 2024-02-08 14:11:10 +08:00
zxhlyh
1b04382a9b fix: chat agent mode content copy (#2418) 2024-02-07 21:23:47 +08:00
JonahCui
71e5828d41 feat: add support for smtp when send email (#2409) 2024-02-07 18:08:41 +08:00
Bowen Liang
65a02f7d32 chore: apply F811 linter rule to eliminate redefined imports and methods (#2412) 2024-02-07 16:28:45 +08:00
WANG Lei
acf9174bef fix: studio/api doc (#2415) 2024-02-07 16:28:09 +08:00
crazywoola
243ca5b1e2 fix: typo in package path of core.splitter (#2411) 2024-02-07 15:34:02 +08:00
zxhlyh
f6059c377c fix: api based extension modal title (#2414) 2024-02-07 15:01:53 +08:00
624 changed files with 11945 additions and 8999 deletions

View File

@@ -10,7 +10,9 @@ body:
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "Pleas do not modify this template :) and fill in all the required fields."
required: true
- type: input

View File

@@ -10,7 +10,9 @@ body:
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "Pleas do not modify this template :) and fill in all the required fields."
required: true
- type: textarea
attributes:

View File

@@ -10,7 +10,9 @@ body:
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "Pleas do not modify this template :) and fill in all the required fields."
required: true
- type: textarea
attributes:

View File

@@ -10,7 +10,9 @@ body:
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "Pleas do not modify this template :) and fill in all the required fields."
required: true
- type: textarea
attributes:

View File

@@ -10,7 +10,9 @@ body:
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- label: I confirm that I am using English to file this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
required: true
- label: "Pleas do not modify this template :) and fill in all the required fields."
required: true
- type: input
attributes:

.github/pull_request_template.md (new file, 30 lines)
View File

@@ -0,0 +1,30 @@
# Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes # (issue)
## Type of Change
Please delete options that are not relevant.
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
# How Has This Been Tested?
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration
- [ ] TODO
# Suggested Checklist:
- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] My changes generate no new warnings
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
- [ ] `optional` I have made corresponding changes to the documentation
- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
- [ ] `optional` New and existing unit tests pass locally with my changes

.github/workflows/tool-test-sdks.yaml (new file, 34 lines)
View File

@@ -0,0 +1,34 @@
name: Run Unit Test For SDKs
on:
pull_request:
branches:
- main
jobs:
build:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [16, 18, 20]
defaults:
run:
working-directory: sdks/nodejs-client
steps:
- uses: actions/checkout@v4
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v4
with:
node-version: ${{ matrix.node-version }}
cache: ''
cache-dependency-path: 'yarn.lock'
- name: Install Dependencies
run: yarn install
- name: Test
run: yarn test

View File

@@ -21,6 +21,17 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>
<p align="center">
<a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
</a>
<ul align="center" style="text-decoration: none; list-style: none;">
<li> US EST: 09:00 (9:00 AM)</li>
<li> CET: 15:00 (3:00 PM)</li>
<li> CST: 22:00 (10:00 PM)</li>
</ul>
</p>
<p align="center">
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs

View File

@@ -81,11 +81,17 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
# Model Configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
# Mail configuration, support: resend
MAIL_TYPE=
# Mail configuration, support: resend, smtp
MAIL_TYPE=resend
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
SMTP_SERVER=smtp.gmail.com
SMTP_PORT=587
SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=false
# Sentry configuration
SENTRY_DSN=
@@ -124,3 +130,5 @@ UNSTRUCTURED_API_URL=
SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=
BATCH_UPLOAD_LIMIT=10
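
These SMTP variables feed a standard SMTP client. A minimal consumption sketch using only the stdlib (smtplib usage only, not Dify's actual mail extension; the recipient and subject are placeholders):

import os
import smtplib
from email.message import EmailMessage

msg = EmailMessage()
msg['From'] = os.environ.get('MAIL_DEFAULT_SEND_FROM', 'no-reply <no-reply@dify.ai>')
msg['To'] = 'user@example.com'           # placeholder recipient
msg['Subject'] = 'SMTP configuration test'
msg.set_content('Sent with the SMTP_* settings above.')

server = smtplib.SMTP(os.environ['SMTP_SERVER'], int(os.environ['SMTP_PORT']))
if os.environ.get('SMTP_USE_TLS', 'false').lower() == 'true':
    server.starttls()                    # upgrade the connection when SMTP_USE_TLS=true
server.login(os.environ['SMTP_USERNAME'], os.environ['SMTP_PASSWORD'])
server.send_message(msg)
server.quit()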

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import os
from werkzeug.exceptions import Unauthorized
@@ -39,10 +38,11 @@ from extensions import (
from extensions.ext_database import db
from extensions.ext_login import login_manager
from libs.passport import PassportService
# DO NOT REMOVE BELOW
from services.account_service import AccountService
# DO NOT REMOVE BELOW
from events import event_handlers
from models import account, dataset, model, source, task, tool, tools, web
# DO NOT REMOVE ABOVE
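
The reordering matters because handler registration is an import-time side effect: nothing connects until `from events import event_handlers` actually runs. A hedged sketch of the pattern (blinker-style signals are an assumption, not Dify's exact wiring):

from blinker import signal

app_was_created = signal('app-was-created')   # hypothetical signal name

@app_was_created.connect                      # connection happens when this module is imported
def handle_app_was_created(sender, **kwargs):
    ...                                       # handler body elided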

View File

@@ -6,15 +6,15 @@ import click
from flask import current_app
from werkzeug.exceptions import NotFound
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
from models.account import Tenant
from models.dataset import Dataset
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account
from models.provider import Provider, ProviderModel
@@ -124,14 +124,15 @@ def reset_encrypt_key_pair():
'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
def create_qdrant_indexes():
@click.command('vdb-migrate', help='migrate vector db.')
def vdb_migrate():
"""
Migrate other vector database datas to Qdrant.
Migrate vector database data to the target vector database.
"""
click.echo(click.style('Start create qdrant indexes.', fg='green'))
click.echo(click.style('Start migrating vector db.', fg='green'))
create_count = 0
config = current_app.config
vector_type = config.get('VECTOR_STORE')
page = 1
while True:
try:
@@ -140,54 +141,101 @@ def create_qdrant_indexes():
except NotFound:
break
model_manager = ModelManager()
page += 1
for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model
)
except Exception:
continue
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
create_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
try:
click.echo('Create dataset vdb index: {}'.format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] == vector_type:
continue
if vector_type == "weaviate":
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "qdrant":
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
raise ValueError('Dataset Collection Binding does not exist!')
else:
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'qdrant',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "milvus":
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
vector = Vector(dataset)
click.echo(f"vdb_migrate {dataset.id}")
try:
vector.delete()
except Exception as e:
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
vector.create(documents)
except Exception as e:
raise e
click.echo(f"Dataset {dataset.id} create successfully.")
db.session.add(dataset)
db.session.commit()
create_count += 1
except Exception as e:
db.session.rollback()
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Created {} dataset indexes.'.format(create_count), fg='green'))
@@ -196,4 +244,4 @@ def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(vdb_migrate)
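
Condensed, the per-dataset loop above reduces to the sketch below; load_completed_documents is a hypothetical helper standing in for the DatasetDocument/DocumentSegment queries in the diff, and the qdrant collection-binding branch is folded away:

import json

def migrate_one(dataset, vector_type):
    if dataset.index_struct_dict and dataset.index_struct_dict['type'] == vector_type:
        return                                        # already on the target store, skip
    collection_name = "Vector_index_" + dataset.id.replace("-", "_") + '_Node'
    dataset.index_struct = json.dumps({
        "type": vector_type,
        "vector_store": {"class_prefix": collection_name},
    })
    vector = Vector(dataset)                          # factory resolves the target store
    vector.delete()                                   # drop any old collection
    vector.create(load_completed_documents(dataset))  # re-embed completed, enabled segments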

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import os
import dotenv
@@ -39,7 +38,9 @@ DEFAULTS = {
'LOG_LEVEL': 'INFO',
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
@@ -57,6 +58,8 @@ DEFAULTS = {
'BILLING_ENABLED': 'False',
'CAN_REPLACE_LOGO': 'False',
'ETL_TYPE': 'dify',
'KEYWORD_STORE': 'jieba',
'BATCH_UPLOAD_LIMIT': 20
}
@@ -87,7 +90,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.5.4"
self.CURRENT_VERSION = "0.5.7"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -183,7 +186,7 @@ class Config:
# Currently, only support: qdrant, milvus, zilliz, weaviate
# ------------------------
self.VECTOR_STORE = get_env('VECTOR_STORE')
self.KEYWORD_STORE = get_env('KEYWORD_STORE')
# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
@@ -209,6 +212,12 @@ class Config:
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
self.RESEND_API_URL = get_env('RESEND_API_URL')
# SMTP settings
self.SMTP_SERVER = get_env('SMTP_SERVER')
self.SMTP_PORT = get_env('SMTP_PORT')
self.SMTP_USERNAME = get_env('SMTP_USERNAME')
self.SMTP_PASSWORD = get_env('SMTP_PASSWORD')
self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
# ------------------------
# Workspace Configurations.
@@ -254,8 +263,10 @@ class Config:
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')
self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
@@ -280,6 +291,8 @@ class Config:
self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
class CloudEditionConfig(Config):
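
The get_env/get_bool_env helpers used throughout presumably resolve against the DEFAULTS table shown earlier; a minimal sketch under that assumption:

import os

def get_env(key):
    # the environment wins; otherwise fall back to the DEFAULTS entry
    return os.environ.get(key, DEFAULTS.get(key))

def get_bool_env(key):
    # flags such as SMTP_USE_TLS and HOSTED_OPENAI_PAID_ENABLED are stored as strings
    return get_env(key).lower() == 'true'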

View File

@@ -1,9 +1,8 @@
import json
from models.model import AppModelConfig
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA']
language_timezone_mapping = {
'en-US': 'America/New_York',
@@ -16,8 +15,10 @@ language_timezone_mapping = {
'ko-KR': 'Asia/Seoul',
'ru-RU': 'Europe/Moscow',
'it-IT': 'Europe/Rome',
'uk-UA': 'Europe/Kyiv',
}
def supported_language(lang):
if lang in languages:
return lang
@@ -26,6 +27,7 @@ def supported_language(lang):
.format(lang=lang))
raise ValueError(error)
user_input_form_template = {
"en-US": [
{
@@ -67,6 +69,16 @@ user_input_form_template = {
}
}
],
"ua-UK": [
{
"paragraph": {
"label": "Запит",
"variable": "default_input",
"required": False,
"default": ""
}
}
],
}
demo_model_templates = {
@@ -145,7 +157,7 @@ demo_model_templates = {
'Italian',
]
}
},{
}, {
"paragraph": {
"label": "Query",
"variable": "query",
@@ -272,7 +284,7 @@ demo_model_templates = {
"意大利语",
]
}
},{
}, {
"paragraph": {
"label": "文本内容",
"variable": "query",
@@ -323,5 +335,130 @@ demo_model_templates = {
)
}
],
'uk-UA': [{
"name": "Помічник перекладу",
"icon": "",
"icon_background": "",
"description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.",
"mode": "completion",
"model_config": AppModelConfig(
provider="openai",
model_id="gpt-3.5-turbo-instruct",
configs={
"prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n",
"prompt_variables": [
{
"key": "target_language",
"name": "Цільова мова",
"description": "Мова, на яку ви хочете перекласти.",
"type": "select",
"default": "Ukrainian",
"options": [
"Chinese",
"English",
"Japanese",
"French",
"Russian",
"German",
"Spanish",
"Korean",
"Italian",
],
},
],
"completion_params": {
"max_token": 1000,
"temperature": 0,
"top_p": 0,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
},
opening_statement="",
suggested_questions=None,
pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
"top_p": 0,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
}),
user_input_form=json.dumps([
{
"select": {
"label": "Цільова мова",
"variable": "target_language",
"description": "Мова, на яку ви хочете перекласти.",
"default": "Chinese",
"required": True,
'options': [
'Chinese',
'English',
'Japanese',
'French',
'Russian',
'German',
'Spanish',
'Korean',
'Italian',
]
}
}, {
"paragraph": {
"label": "Запит",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
},
{
"name": "AI інтерв’юер фронтенду",
"icon": "",
"icon_background": "",
"description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.",
"mode": "chat",
"model_config": AppModelConfig(
provider="openai",
model_id="gpt-3.5-turbo",
configs={
"introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
"prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
"prompt_variables": [],
"completion_params": {
"max_token": 300,
"temperature": 0.8,
"top_p": 0.9,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
},
opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
suggested_questions=None,
pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
"top_p": 0.9,
"presence_penalty": 0.1,
"frequency_penalty": 0.1,
},
}),
user_input_form=None
),
}
],
}
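
A short usage sketch of the validation helper defined above, now that Ukrainian is registered:

supported_language('uk-UA')   # returns 'uk-UA'
supported_language('xx-XX')   # raises ValueError, message built from the template above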

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
import logging
from datetime import datetime
@@ -125,19 +124,13 @@ class AppListApi(Resource):
available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
if provider_model not in available_models_names:
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
if not model_instance:
if not default_model_entity:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
"No Default System Reasoning Model available. Please configure "
"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model
model_config_dict["model"]["provider"] = default_model_entity.provider
model_config_dict["model"]["name"] = default_model_entity.model
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,

View File

@@ -1,8 +1,7 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
from flask_restful import Resource
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
@@ -46,7 +45,8 @@ class ChatMessageAudioApi(Resource):
try:
response = AudioService.transcript_asr(
tenant_id=app_model.tenant_id,
file=file
file=file,
end_user=None,
)
return response
@@ -72,7 +72,7 @@ class ChatMessageAudioApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
logging.exception(f"internal server error, {str(e)}.")
raise InternalServerError()
@@ -83,10 +83,12 @@ class ChatMessageTextApi(Resource):
def post(self, app_id):
app_id = str(app_id)
app_model = _get_app(app_id, None)
try:
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
@@ -113,9 +115,50 @@ class ChatMessageTextApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
logging.exception(f"internal server error, {str(e)}.")
raise InternalServerError()
class TextModesApi(Resource):
def get(self, app_id: str):
app_model = _get_app(str(app_id))
try:
parser = reqparse.RequestParser()
parser.add_argument('language', type=str, required=True, location='args')
args = parser.parse_args()
response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id,
language=args['language'],
)
return response
except services.errors.audio.ProviderNotSupportTextToSpeechLanageServiceError:
raise AppUnavailableError("Text-to-audio voices language parameter is missing.")
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logging.exception(f"internal server error, {str(e)}.")
raise InternalServerError()
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
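
A hypothetical client call against the new voices route (the base URL, app_id, and console token are placeholders; the required language query argument follows the reqparse rule above):

import requests   # illustrative client, not part of the codebase

resp = requests.get(
    f'{CONSOLE_API_URL}/console/api/apps/{app_id}/text-to-audio/voices',
    params={'language': 'zh-Hans'},                        # required, per reqparse above
    headers={'Authorization': f'Bearer {console_token}'},  # assumption: console session auth
)
voices = resp.json()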

View File

@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
import flask_login
from flask import Response, stream_with_context
@@ -169,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
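
The rewrite is behavior-preserving here: `yield from` delegates the whole iterator (and also forwards send()/throw()/close()), which is why the same mechanical change recurs in the later compact_response hunks. Side by side:

def generate() -> Generator:   # before
    for chunk in response:
        yield chunk

def generate() -> Generator:   # after, identical for plain iteration
    yield from response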

View File

@@ -1,6 +1,7 @@
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_login import current_user
@@ -246,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask import request
from flask_login import current_user

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from datetime import datetime
from decimal import Decimal

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import flask_login
from flask import current_app, request
from flask_restful import Resource, reqparse
@@ -8,7 +7,7 @@ from controllers.console import api
from controllers.console.setup import setup_required
from libs.helper import email
from libs.password import valid_password
from services.account_service import AccountService
from services.account_service import AccountService, TenantService
class LoginApi(Resource):
@@ -30,6 +29,8 @@ class LoginApi(Resource):
except services.errors.account.AccountLoginError:
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
TenantService.create_owner_tenant_if_not_exist(account)
AccountService.update_last_login(account, request)
# todo: return the user info

View File

@@ -10,7 +10,7 @@ from constants.languages import languages
from extensions.ext_database import db
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService
from services.account_service import AccountService, RegisterService, TenantService
from .. import api
@@ -76,6 +76,8 @@ class OAuthCallback(Resource):
account.initialized_at = datetime.utcnow()
db.session.commit()
TenantService.create_owner_tenant_if_not_exist(account)
AccountService.update_last_login(account, request)
token = AccountService.get_account_jwt_token(account)
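
Both the password and OAuth login paths now repair accounts that lost their owner workspace (#2435). A hedged sketch of what create_owner_tenant_if_not_exist plausibly does; everything beyond the method name is an assumption:

def create_owner_tenant_if_not_exist(account):
    # hypothetical body: give the account an owned workspace if it has none
    if TenantService.get_join_tenants(account):
        return
    tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
    TenantService.create_tenant_member(tenant, account, role='owner')
    account.current_tenant = tenant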

View File

@@ -9,8 +9,9 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.login import login_required
@@ -173,14 +174,15 @@ class DataSourceNotionApi(Resource):
if not data_source_binding:
raise NotFound('Data source binding not found.')
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type
notion_page_type=page_type,
notion_access_token=data_source_binding.access_token,
tenant_id=current_user.current_tenant_id
)
text_docs = loader.load()
text_docs = extractor.extract()
return {
'content': "\n".join([doc.page_content for doc in text_docs])
}, 200
@@ -192,11 +194,31 @@ class DataSourceNotionApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
notion_info_list = args['notion_info_list']
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
for page in notion_info['pages']:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page['page_id'],
"notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
indexing_runner = IndexingRunner()
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
args['process_rule'], args['doc_form'],
args['doc_language'])
return response, 200
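
The NotionLoader-to-NotionExtractor migration keeps the flow intact; condensed, using only the constructor arguments visible in the hunk:

extractor = NotionExtractor(
    notion_workspace_id=workspace_id,
    notion_obj_id=page_id,
    notion_page_type=page_type,
    notion_access_token=data_source_binding.access_token,
    tenant_id=current_user.current_tenant_id,
)
text_docs = extractor.extract()    # returns rag Document objects
content = "\n".join(doc.page_content for doc in text_docs)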

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import flask_restful
from flask import current_app, request
from flask_login import current_user
@@ -16,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -179,9 +179,9 @@ class DatasetApi(Resource):
location='json', store_missing=False,
type=_validate_description_length)
parser.add_argument('indexing_technique', type=str, location='json',
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help='Invalid indexing technique.')
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help='Invalid indexing technique.')
parser.add_argument('permission', type=str, location='json', choices=(
'only_me', 'all_team_members'), help='Invalid permission.')
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
@@ -259,7 +259,7 @@ class DatasetIndexingEstimateApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument('indexing_technique', type=str, required=True,
parser.add_argument('indexing_technique', type=str, required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
@@ -269,6 +269,7 @@ class DatasetIndexingEstimateApi(Resource):
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
if args['info_list']['data_source_type'] == 'upload_file':
file_ids = args['info_list']['file_info_list']['file_ids']
file_details = db.session.query(UploadFile).filter(
@@ -279,37 +280,45 @@ class DatasetIndexingEstimateApi(Resource):
if file_details is None:
raise NotFound("File not found.")
indexing_runner = IndexingRunner()
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file_detail,
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
elif args['info_list']['data_source_type'] == 'notion_import':
indexing_runner = IndexingRunner()
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
args['info_list']['notion_info_list'],
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
notion_info_list = args['info_list']['notion_info_list']
for notion_info in notion_info_list:
workspace_id = notion_info['workspace_id']
for page in notion_info['pages']:
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": workspace_id,
"notion_obj_id": page['page_id'],
"notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=args['doc_form']
)
extract_settings.append(extract_setting)
else:
raise ValueError('Data source type not supported')
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
args['process_rule'], args['doc_form'],
args['doc_language'], args['dataset_id'],
args['indexing_technique'])
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
return response, 200
@@ -509,4 +518,3 @@ api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')

View File

@@ -1,6 +1,4 @@
# -*- coding:utf-8 -*-
from datetime import datetime
from typing import List
from flask import request
from flask_login import current_user
@@ -34,6 +32,7 @@ from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.document_fields import (
@@ -71,7 +70,7 @@ class DocumentResource(Resource):
return document
def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
@@ -97,7 +96,7 @@ class GetProcessRuleApi(Resource):
req_data = request.args
document_id = req_data.get('document_id')
# get default rules
mode = DocumentService.DEFAULT_RULES['mode']
rules = DocumentService.DEFAULT_RULES['rules']
@@ -296,8 +295,8 @@ class DatasetInitApi(Resource):
)
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -364,16 +363,22 @@ class DocumentIndexingEstimateApi(DocumentResource):
if not file:
raise NotFound('File not found.')
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file,
document_model=document.doc_form
)
indexing_runner = IndexingRunner()
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
data_process_rule_dict, None,
'English', dataset_id)
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting],
data_process_rule_dict, document.doc_form,
'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -404,6 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()
info_list = []
extract_settings = []
for document in documents:
if document.indexing_status in ['completed', 'error']:
raise DocumentAlreadyFinishedError()
@@ -426,42 +432,49 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
}
info_list.append(notion_info)
if dataset.data_source_type == 'upload_file':
file_details = db.session.query(UploadFile).filter(
UploadFile.tenant_id == current_user.current_tenant_id,
UploadFile.id.in_(info_list)
).all()
if document.data_source_type == 'upload_file':
file_id = data_source_info['upload_file_id']
file_detail = db.session.query(UploadFile).filter(
UploadFile.tenant_id == current_user.current_tenant_id,
UploadFile.id == file_id
).first()
if file_details is None:
raise NotFound("File not found.")
if file_detail is None:
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type="upload_file",
upload_file=file_detail,
document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == 'notion_import':
extract_setting = ExtractSetting(
datasource_type="notion_import",
notion_info={
"notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'],
"notion_page_type": data_source_info['type'],
"tenant_id": current_user.current_tenant_id
},
document_model=document.doc_form
)
extract_settings.append(extract_setting)
else:
raise ValueError('Data source type not supported')
indexing_runner = IndexingRunner()
try:
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
data_process_rule_dict, None,
'English', dataset_id)
response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
data_process_rule_dict, document.doc_form,
'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
elif dataset.data_source_type == 'notion_import':
indexing_runner = IndexingRunner()
try:
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
info_list,
data_process_rule_dict,
None, 'English', dataset_id)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
else:
raise ValueError('Data source type not supported')
return response

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import uuid
from datetime import datetime
@@ -143,8 +142,8 @@ class DatasetDocumentSegmentApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -234,8 +233,8 @@ class DatasetDocumentSegmentAddApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
@@ -286,8 +285,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment

View File

@@ -76,8 +76,8 @@ class HitTestingApi(Resource):
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model or Reranking Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
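
These hunks all strip f-prefixes from strings that contain no placeholders, in line with the ruff pyflakes pass (#2420). The rule in miniature:

message = f"No Embedding Model available."   # F541: f-string without any placeholders
message = "No Embedding Model available."    # equivalent plain string, no warning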

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
@@ -86,6 +85,7 @@ class ChatTextApi(InstalledAppResource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
return {'data': response.data.decode('latin1')}

View File

@@ -1,8 +1,8 @@
# -*- coding:utf-8 -*-
import json
import logging
from collections.abc import Generator
from datetime import datetime
from typing import Generator, Union
from typing import Union
from flask import Response, stream_with_context
from flask_login import current_user
@@ -164,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from datetime import datetime
from flask_login import current_user

View File

@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_login import current_user
@@ -123,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
from flask import current_app
@@ -78,7 +77,7 @@ class ExploreAppMetaApi(InstalledAppResource):
# get all tools
tools = agent_config.get('tools', [])
url_prefix = (current_app.config.get("CONSOLE_API_URL")
+ f"/console/api/workspaces/current/tool-provider/builtin/")
+ "/console/api/workspaces/current/tool-provider/builtin/")
for tool in tools:
keys = list(tool.keys())
if len(keys) >= 4:

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with
from sqlalchemy import and_

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from functools import wraps
from flask import current_app, request

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
import logging

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from datetime import datetime
import pytz

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask import current_app
from flask_login import current_user
from flask_restful import Resource, abort, fields, marshal_with, reqparse
@@ -12,6 +11,7 @@ from libs.helper import TimestampField
from libs.login import login_required
from models.account import Account
from services.account_service import RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
account_fields = {
'id': fields.String,
@@ -72,6 +72,13 @@ class MemberInviteEmailApi(Resource):
'email': invitee_email,
'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
})
except AccountAlreadyInTenantError:
invitation_results.append({
'status': 'success',
'email': invitee_email,
'url': f'{console_web_url}/signin'
})
break
except Exception as e:
invitation_results.append({
'status': 'failed',

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import logging
from flask import request

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
from functools import wraps

View File

@@ -41,7 +41,7 @@ class WorkspaceWebappLogoApi(Resource):
webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
if not webapp_logo_file_id:
raise NotFound(f'webapp logo is not found')
raise NotFound('webapp logo is not found')
try:
generator, mimetype = FileService.get_public_image_preview(

View File

@@ -32,7 +32,7 @@ class ToolFilePreviewApi(Resource):
)
if not result:
raise NotFound(f'file is not found')
raise NotFound('file is not found')
generator, mimetype = result
except Exception:

View File

@@ -1,27 +0,0 @@
from extensions.ext_database import db
from models.model import EndUser
def create_or_update_end_user_for_user_id(app_model, user_id):
"""
Create or update session terminal based on user ID.
"""
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user

View File

@@ -1,17 +1,16 @@
# -*- coding:utf-8 -*-
import json
from flask import current_app
from flask_restful import fields, marshal_with
from flask_restful import fields, marshal_with, Resource
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider
class AppParameterApi(AppApiResource):
class AppParameterApi(Resource):
"""Resource for app variables."""
variable_fields = {
@@ -43,8 +42,9 @@ class AppParameterApi(AppApiResource):
'system_parameters': fields.Nested(system_parameters_fields)
}
@validate_app_token
@marshal_with(parameters_fields)
def get(self, app_model: App, end_user):
def get(self, app_model: App):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config
@@ -65,8 +65,9 @@ class AppParameterApi(AppApiResource):
}
}
class AppMetaApi(AppApiResource):
def get(self, app_model: App, end_user):
class AppMetaApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get app meta"""
app_model_config: AppModelConfig = app_model.app_model_config
@@ -78,7 +79,7 @@ class AppMetaApi(AppApiResource):
# get all tools
tools = agent_config.get('tools', [])
url_prefix = (current_app.config.get("CONSOLE_API_URL")
+ f"/console/api/workspaces/current/tool-provider/builtin/")
+ "/console/api/workspaces/current/tool-provider/builtin/")
for tool in tools:
keys = list(tool.keys())
if len(keys) >= 4:
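This file previews the Service API migration repeated throughout the diffs below: the AppApiResource base class, which injected app_model and end_user positionally into every handler, gives way to plain flask_restful Resource plus an explicit validate_app_token decorator per method, so each handler declares only the parameters it actually receives; AppParameterApi.get, for instance, no longer accepts an unused end_user. A stripped-down sketch of the consumer side, with the token lookup stubbed out (stub names are invented):

    from functools import wraps
    from flask_restful import Resource

    def validate_app_token_stub(view):
        # stand-in for the real decorator: resolve the App from the request's
        # bearer token and hand it to the view as a keyword argument
        @wraps(view)
        def decorated(*args, **kwargs):
            kwargs['app_model'] = object()  # would be the App row in the real code
            return view(*args, **kwargs)
        return decorated

    class AppParameterApiSketch(Resource):
        @validate_app_token_stub
        def get(self, app_model):           # no unused end_user parameter
            return {'has_app': app_model is not None}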

View File

@@ -1,7 +1,7 @@
import logging
from flask import request
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig
from models.model import App, AppModelConfig, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -30,8 +30,9 @@ from services.errors.audio import (
)
class AudioApi(AppApiResource):
def post(self, app_model: App, end_user):
class AudioApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
def post(self, app_model: App, end_user: EndUser):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
raise InternalServerError()
class TextApi(AppApiResource):
def post(self, app_model: App, end_user):
class TextApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
args = parser.parse_args()
@@ -85,7 +86,8 @@ class TextApi(AppApiResource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=args['text'],
end_user=args['user'],
end_user=end_user,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming']
)

View File

@@ -1,14 +1,14 @@
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
@@ -18,17 +18,19 @@ from controllers.service_api.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App, EndUser
from services.completion_service import CompletionService
class CompletionApi(AppApiResource):
def post(self, app_model, end_user):
class CompletionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'completion':
raise AppUnavailableError()
@@ -37,16 +39,12 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
args['auto_generate_name'] = False
try:
@@ -81,29 +79,20 @@ class CompletionApi(AppApiResource):
raise InternalServerError()
class CompletionStopApi(AppApiResource):
def post(self, app_model, end_user, task_id):
class CompletionStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'completion':
raise AppUnavailableError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
class ChatApi(AppApiResource):
def post(self, app_model, end_user):
class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
@@ -113,7 +102,6 @@ class ChatApi(AppApiResource):
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
@@ -121,9 +109,6 @@ class ChatApi(AppApiResource):
streaming = args['response_mode'] == 'streaming'
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
response = CompletionService.completion(
app_model=app_model,
@@ -156,22 +141,12 @@ class ChatApi(AppApiResource):
raise InternalServerError()
class ChatStopApi(AppApiResource):
def post(self, app_model, end_user, task_id):
class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'chat':
raise NotChatAppError()
if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()
user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")
ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
return {'result': 'success'}, 200
@@ -182,8 +157,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
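compact_response keeps its dual behavior across these endpoints: a dict becomes a one-shot JSON response, while a generator is wrapped in stream_with_context and served as a text/event-stream. A standalone sketch of the same pattern (the /demo endpoint is hypothetical):

    import json
    from collections.abc import Generator
    from typing import Union

    from flask import Flask, Response, stream_with_context

    app = Flask(__name__)

    def compact_response(response: Union[dict, Generator]) -> Response:
        if isinstance(response, dict):
            return Response(response=json.dumps(response), status=200,
                            mimetype='application/json')

        def generate() -> Generator:
            yield from response

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')

    @app.route('/demo')
    def demo():
        def events() -> Generator:
            for i in range(3):
                yield f"data: chunk {i}\n\n"  # SSE framing: data line + blank line
        return compact_response(events())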

View File

@@ -1,53 +1,44 @@
# -*- coding:utf-8 -*-
from flask import request
from flask_restful import marshal_with, reqparse
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import App, EndUser
from services.conversation_service import ConversationService
class ConversationApi(AppApiResource):
class ConversationApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource):
class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def delete(self, app_model, end_user, c_id):
def delete(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
user = request.get_json().get('user')
if end_user is None and user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
try:
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
@@ -55,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
return {"result": "success"}, 204
class ConversationRenameApi(AppApiResource):
class ConversationRenameApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
def post(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
@@ -66,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return ConversationService.rename(
app_model,

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException

View File

@@ -1,30 +1,27 @@
from flask import request
from flask_restful import marshal_with
from flask_restful import Resource, marshal_with
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields
from models.model import App, EndUser
from services.file_service import FileService
class FileApi(AppApiResource):
class FileApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields)
def post(self, app_model, end_user):
def post(self, app_model: App, end_user: EndUser):
file = request.files['file']
user_args = request.form.get('user')
if end_user is None and user_args is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
# check file
if 'file' not in request.files:

View File

@@ -1,21 +1,18 @@
# -*- coding:utf-8 -*-
from flask_restful import fields, marshal_with, reqparse
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from extensions.ext_database import db
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from models.model import EndUser, Message
from models.model import App, EndUser
from services.message_service import MessageService
class MessageListApi(AppApiResource):
class MessageListApi(Resource):
feedback_fields = {
'rating': fields.String
}
@@ -71,8 +68,9 @@ class MessageListApi(AppApiResource):
'data': fields.List(fields.Nested(message_fields))
}
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()
@@ -80,12 +78,8 @@ class MessageListApi(AppApiResource):
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return MessageService.pagination_by_first_id(app_model, end_user,
args['conversation_id'], args['first_id'], args['limit'])
@@ -95,18 +89,15 @@ class MessageListApi(AppApiResource):
raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(AppApiResource):
def post(self, app_model, end_user, message_id):
class MessageFeedbackApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
parser.add_argument('user', type=str, location='json')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
except services.errors.message.MessageNotExistsError:
@@ -115,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'}
class MessageSuggestedApi(AppApiResource):
def get(self, app_model, end_user, message_id):
class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
raise NotChatAppError()
try:
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
).first()
if end_user is None and message.from_end_user_id is not None:
user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.id == message.from_end_user_id,
EndUser.type == 'service_api'
).first()
else:
user = end_user
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=user,
user=end_user,
message_id=message_id,
check_enabled=False
)

View File

@@ -1,7 +1,6 @@
import json
from flask import request
from flask_login import current_user
from flask_restful import marshal, reqparse
from sqlalchemy import desc
from werkzeug.exceptions import NotFound

View File

@@ -46,8 +46,8 @@ class SegmentApi(DatasetApiResource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
@@ -90,8 +90,8 @@ class SegmentApi(DatasetApiResource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -182,8 +182,8 @@ class DatasetSegmentApi(DatasetApiResource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment

View File

@@ -1,23 +1,40 @@
# -*- coding:utf-8 -*-
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in
from flask_restful import Resource
from pydantic import BaseModel
from werkzeug.exceptions import NotFound, Unauthorized
from extensions.ext_database import db
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin
from models.model import ApiToken, App
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
def validate_app_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
class WhereisUserArg(Enum):
"""
Enum for whereis_user_arg.
"""
QUERY = 'query'
JSON = 'json'
FORM = 'form'
class FetchUserArg(BaseModel):
fetch_from: WhereisUserArg
required: bool = False
def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token('app')
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
@@ -30,16 +47,35 @@ def validate_app_token(view=None):
if not app_model.enable_api:
raise NotFound()
return view(app_model, None, *args, **kwargs)
return decorated
kwargs['app_model'] = app_model
if view:
if fetch_user_arg:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get('user')
else:
# use default-user
user_id = None
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")
if user_id:
user_id = str(user_id)
kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)
# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
def cloud_edition_billing_resource_check(resource: str,
api_token_type: str,
@@ -129,8 +165,33 @@ def validate_and_get_api_token(scope=None):
return api_token
class AppApiResource(Resource):
method_decorators = [validate_app_token]
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
"""
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = 'DEFAULT-USER'
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()
if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()
return end_user
class DatasetApiResource(Resource):
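The heart of the refactor: validate_app_token becomes a decorator factory usable both bare and with arguments. Bare, it injects only app_model; with fetch_user_arg=FetchUserArg(...), it reads the user id from the query string, JSON body, or form as WhereisUserArg directs, falls back to a shared DEFAULT-USER session when none is given, and injects the resulting EndUser as end_user. A stripped-down sketch of the optional-argument decorator shape, with the token and database lookups replaced by stubs:

    from collections.abc import Callable
    from functools import wraps
    from typing import Optional

    def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg=None):
        def decorator(view_func):
            @wraps(view_func)
            def decorated_view(*args, **kwargs):
                kwargs['app_model'] = 'app-stub'          # real code: App from the token
                if fetch_user_arg is not None:
                    kwargs['end_user'] = 'end-user-stub'  # real code: EndUser lookup/create
                return view_func(*args, **kwargs)
            return decorated_view

        if view:
            return decorator(view)  # used bare: @validate_app_token
        return decorator            # used with arguments: @validate_app_token(...)

    @validate_app_token
    def bare_handler(app_model):
        return app_model

    @validate_app_token(fetch_user_arg=True)
    def user_handler(app_model, end_user):
        return app_model, end_user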

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
from flask import current_app
@@ -77,7 +76,7 @@ class AppMeta(WebApiResource):
# get all tools
tools = agent_config.get('tools', [])
url_prefix = (current_app.config.get("CONSOLE_API_URL")
+ f"/console/api/workspaces/current/tool-provider/builtin/")
+ "/console/api/workspaces/current/tool-provider/builtin/")
for tool in tools:
keys = list(tool.keys())
if len(keys) >= 4:

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
@@ -69,17 +68,23 @@ class AudioApi(WebApiResource):
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
logging.exception(f"internal server error: {str(e)}")
raise InternalServerError()
class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.text_to_speech_dict['enabled']:
raise AppUnavailableError()
try:
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
end_user=end_user.external_user_id,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
@@ -106,7 +111,7 @@ class TextApi(WebApiResource):
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
logging.exception(f"internal server error: {str(e)}")
raise InternalServerError()
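A side note on the two logging changes in this file: logging.exception already appends the active exception's traceback, message included, so interpolating str(e) mostly duplicates that information inline; where interpolation is wanted, the lazy percent style defers the string build until the record is actually emitted. A sketch:

    import logging

    logger = logging.getLogger(__name__)

    try:
        raise ValueError("boom")
    except Exception as e:
        # the traceback (ending in "ValueError: boom") is appended automatically;
        # %s-style arguments are only formatted if the record is emitted
        logger.exception("internal server error: %s", e)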

View File

@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_restful import reqparse
@@ -154,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException

View File

@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union
from flask import Response, stream_with_context
from flask_restful import fields, marshal_with, reqparse
@@ -160,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:
def generate() -> Generator:
for chunk in response:
yield chunk
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import uuid
from flask import request

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from flask import current_app
from flask_restful import fields, marshal_with

View File

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
from functools import wraps
from flask import request

View File

@@ -1,5 +1,5 @@
import logging
from typing import List, Optional
from typing import Optional
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.model_runtime.callbacks.base_callback import Callback
@@ -17,7 +17,7 @@ class AgentLLMCallback(Callback):
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Before invoke callback
@@ -38,7 +38,7 @@ class AgentLLMCallback(Callback):
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None):
"""
On new chunk callback
@@ -58,7 +58,7 @@ class AgentLLMCallback(Callback):
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
After invoke callback
@@ -80,7 +80,7 @@ class AgentLLMCallback(Callback):
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> None:
"""
Invoke error callback

View File

@@ -1,4 +1,4 @@
from typing import List, cast
from typing import cast
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
@@ -8,7 +8,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
class CalcTokenMixin:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
"""
Got the rest tokens available for the model after excluding messages tokens and completion max tokens

View File

@@ -1,4 +1,5 @@
from typing import Any, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
@@ -42,7 +43,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@@ -85,7 +86,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@@ -146,7 +147,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@@ -158,7 +159,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),

View File

@@ -1,4 +1,5 @@
from typing import Any, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
@@ -51,7 +52,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
@@ -125,7 +126,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@@ -207,7 +208,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
try:
@@ -215,7 +216,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(
self.model_config,
@@ -264,7 +265,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
return new_messages
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
@@ -275,7 +276,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/

View File

@@ -1,5 +1,6 @@
import re
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
@@ -68,7 +69,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@@ -125,8 +126,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
@@ -153,7 +154,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
@@ -180,7 +181,7 @@ Thought: {agent_scratchpad}
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
@@ -213,8 +214,8 @@ Thought: {agent_scratchpad}
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""

View File

@@ -1,5 +1,6 @@
import re
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@@ -127,7 +128,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_model_config:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
@@ -154,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
@@ -173,8 +174,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
@@ -200,7 +201,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
@@ -227,7 +228,7 @@ Thought: {agent_scratchpad}
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
@@ -260,8 +261,8 @@ Thought: {agent_scratchpad}
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> Agent:

View File

@@ -1,5 +1,6 @@
import time
from typing import Generator, List, Optional, Tuple, Union, cast
from collections.abc import Generator
from typing import Optional, Union, cast
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import (
@@ -84,7 +85,7 @@ class AppRunner:
return rest_tokens
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
prompt_messages: List[PromptMessage]):
prompt_messages: list[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -126,7 +127,7 @@ class AppRunner:
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
-> Tuple[List[PromptMessage], Optional[List[str]]]:
-> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Organize prompt messages
:param context:
@@ -295,7 +296,7 @@ class AppRunner:
tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity,
inputs: dict,
query: str) -> Tuple[bool, dict, str]:
query: str) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id

View File

@@ -38,7 +38,7 @@ class AssistantApplicationRunner(AppRunner):
"""
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
if not app_record:
raise ValueError(f"App not found")
raise ValueError("App not found")
app_orchestration_config = application_generate_entity.app_orchestration_config_entity

View File

@@ -35,7 +35,7 @@ class BasicApplicationRunner(AppRunner):
"""
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
if not app_record:
raise ValueError(f"App not found")
raise ValueError("App not found")
app_orchestration_config = application_generate_entity.app_orchestration_config_entity

View File

@@ -1,7 +1,8 @@
import json
import logging
import time
from typing import Generator, Optional, Union, cast
from collections.abc import Generator
from typing import Optional, Union, cast
from pydantic import BaseModel
@@ -118,7 +119,7 @@ class GenerateTaskPipeline:
}
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
else:
@@ -174,7 +175,7 @@ class GenerateTaskPipeline:
'id': self._message.id,
'message_id': self._message.id,
'mode': self._conversation.mode,
'answer': event.llm_result.message.content,
'answer': self._task_state.llm_result.message.content,
'metadata': {},
'created_at': int(self._message.created_at.timestamp())
}
@@ -201,7 +202,7 @@ class GenerateTaskPipeline:
data = self._error_to_stream_response_data(self._handle_error(event))
yield self._yield_response(data)
break
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
else:
@@ -353,7 +354,7 @@ class GenerateTaskPipeline:
yield self._yield_response(response)
elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
chunk = event.chunk
delta_text = chunk.delta.message.content
if delta_text is None:
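These isinstance rewrites lean on PEP 604: from Python 3.10 on, isinstance accepts an X | Y union object and treats it exactly like the (X, Y) tuple form. A tiny sketch, assuming Python 3.10+:

    class QueueStopEvent: ...
    class QueueMessageEndEvent: ...

    event = QueueStopEvent()

    assert isinstance(event, (QueueStopEvent, QueueMessageEndEvent))  # tuple form
    assert isinstance(event, QueueStopEvent | QueueMessageEndEvent)   # PEP 604 union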

View File

@@ -1,7 +1,7 @@
import logging
import threading
import time
from typing import Any, Dict, Optional
from typing import Any, Optional
from flask import Flask, current_app
from pydantic import BaseModel
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
config: dict[str, Any]
class OutputModerationHandler(BaseModel):

View File

@@ -2,7 +2,8 @@ import json
import logging
import threading
import uuid
from typing import Any, Generator, Optional, Tuple, Union, cast
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from pydantic import ValidationError
@@ -27,6 +28,7 @@ from core.entities.application_entities import (
ModelConfigEntity,
PromptTemplateEntity,
SensitiveWordAvoidanceEntity,
TextToSpeechEntity,
)
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
@@ -571,7 +573,11 @@ class ApplicationManager:
text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
if text_to_speech_dict:
if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
properties['text_to_speech'] = True
properties['text_to_speech'] = TextToSpeechEntity(
enabled=text_to_speech_dict.get('enabled'),
voice=text_to_speech_dict.get('voice'),
language=text_to_speech_dict.get('language'),
)
# sensitive word avoidance
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
@@ -585,7 +591,7 @@ class ApplicationManager:
return AppOrchestrationConfigEntity(**properties)
def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
-> Tuple[Conversation, Message]:
-> tuple[Conversation, Message]:
"""
Initialize generate records
:param application_generate_entity: application generate entity
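Rather than collapsing the text-to-speech config to a bare True, the manager now builds a structured TextToSpeechEntity carrying enabled, voice, and language, which is what lets callers such as the TextApi above read a per-app voice downstream. The entity's definition is not part of this diff; a plausible pydantic sketch, with types and defaults as assumptions:

    from typing import Optional

    from pydantic import BaseModel

    class TextToSpeechEntity(BaseModel):
        # field names follow the constructor call in the hunk above;
        # the types and defaults here are guesses, not the PR's definition
        enabled: bool = False
        voice: Optional[str] = None
        language: Optional[str] = None

    text_to_speech_dict = {'enabled': True, 'voice': 'some-voice'}
    properties = {}
    if text_to_speech_dict.get('enabled'):
        properties['text_to_speech'] = TextToSpeechEntity(
            enabled=text_to_speech_dict.get('enabled'),
            voice=text_to_speech_dict.get('voice'),
            language=text_to_speech_dict.get('language'),
        )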

View File

@@ -1,7 +1,8 @@
import queue
import time
from collections.abc import Generator
from enum import Enum
from typing import Any, Generator
from typing import Any
from sqlalchemy.orm import DeclarativeMeta

View File

@@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Optional, Union, cast
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
@@ -37,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._message_agent_thought = None
@property
def agent_loops(self) -> List[AgentLoop]:
def agent_loops(self) -> list[AgentLoop]:
return self._agent_loops
def clear_agent_loops(self) -> None:
@@ -95,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any
) -> Any:
pass
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
pass
@@ -120,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
@@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
def on_tool_start(
self,
tool_name: str,
tool_inputs: Dict[str, Any],
tool_inputs: dict[str, Any],
) -> None:
"""Do nothing."""
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
def on_tool_end(
self,
tool_name: str,
tool_inputs: Dict[str, Any],
tool_inputs: dict[str, Any],
tool_outputs: str,
) -> None:
"""If not the final action, print out observation."""

View File

@@ -1,9 +1,7 @@
from typing import List
from langchain.schema import Document
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import DatasetQuery, DocumentSegment
from models.model import DatasetRetrieverResource
@@ -40,22 +38,26 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_query)
db.session.commit()
def on_tool_end(self, documents: List[Document]) -> None:
def on_tool_end(self, documents: list[Document]) -> None:
"""Handle tool end."""
for document in documents:
doc_id = document.metadata['doc_id']
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata['doc_id']
)
# if 'dataset_id' in document.metadata:
if 'dataset_id' in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
# add hit count to document segment
db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == doc_id
).update(
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)
db.session.commit()
def return_retriever_resource_info(self, resource: List):
def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info."""
if resource and len(resource) > 0:
for item in resource:
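The hit-count bookkeeping above is rebuilt around a single reusable query: it starts from the index_node_id filter, is narrowed by dataset_id when the document metadata carries one, and is then updated in bulk; synchronize_session=False tells SQLAlchemy to issue one UPDATE without reconciling in-memory objects. A self-contained sketch of the same pattern against a toy table (model and column names are illustrative, not the project's):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Segment(Base):
        __tablename__ = 'segments'
        id = Column(Integer, primary_key=True)
        dataset_id = Column(String)
        hit_count = Column(Integer, default=0)

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Segment(id=1, dataset_id='ds-1', hit_count=0))
        session.commit()

        metadata = {'dataset_id': 'ds-1'}
        query = session.query(Segment).filter(Segment.id == 1)
        if 'dataset_id' in metadata:  # narrow further when the metadata allows
            query = query.filter(Segment.dataset_id == metadata['dataset_id'])

        query.update({Segment.hit_count: Segment.hit_count + 1},
                     synchronize_session=False)  # one UPDATE, no session sync
        session.commit()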

View File

@@ -1,6 +1,6 @@
import os
import sys
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
@@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any
) -> Any:
print_text("\n[on_chat_model_start]\n", color='blue')
@@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
print_text(str(sub_message) + "\n", color='blue')
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print_text("\n[on_llm_start]\n", color='blue')
@@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
chain_type = serialized['id'][-1]
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
@@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
@@ -16,12 +16,12 @@ class LLMChain(LCLLMChain):
model_config: ModelConfigEntity
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
parameters: Dict[str, Any] = {}
parameters: dict[str, Any] = {}
agent_llm_callback: Optional[AgentLLMCallback] = None
def generate(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""

View File

@@ -1,107 +0,0 @@
import tempfile
from pathlib import Path
from typing import List, Optional, Union
import requests
from flask import current_app
from langchain.document_loaders import Docx2txtLoader, TextLoader
from langchain.schema import Document
from core.data_loader.loader.csv_loader import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from core.data_loader.loader.unstructured.unstructured_eml import UnstructuredEmailLoader
from core.data_loader.loader.unstructured.unstructured_markdown import UnstructuredMarkdownLoader
from core.data_loader.loader.unstructured.unstructured_msg import UnstructuredMsgLoader
from core.data_loader.loader.unstructured.unstructured_ppt import UnstructuredPPTLoader
from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredPPTXLoader
from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader
from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader
from extensions.ext_storage import storage
from models.model import UploadFile
SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain']
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(url).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, 'wb') as file:
file.write(response.content)
return cls.load_from_file(file_path, return_text)
@classmethod
def load_from_file(cls, file_path: str, return_text: bool = False,
upload_file: Optional[UploadFile] = None,
is_automatic: bool = False) -> Union[List[Document], str]:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
if etl_type == 'Unstructured':
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
else MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension in ['.docx', '.doc']:
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_extension == '.msg':
loader = UnstructuredMsgLoader(file_path, unstructured_api_url)
elif file_extension == '.eml':
loader = UnstructuredEmailLoader(file_path, unstructured_api_url)
elif file_extension == '.ppt':
loader = UnstructuredPPTLoader(file_path, unstructured_api_url)
elif file_extension == '.pptx':
loader = UnstructuredPPTXLoader(file_path, unstructured_api_url)
elif file_extension == '.xml':
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
else:
# txt
loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
else TextLoader(file_path, autodetect_encoding=True)
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension in ['.docx', '.doc']:
loader = Docx2txtLoader(file_path)
elif file_extension == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()
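
The extractor dispatches purely on file extension, with the Unstructured-backed loaders gated behind the ETL_TYPE config. A minimal usage sketch, assuming an active Flask app context so current_app.config resolves, and hypothetical file paths (the module path is inferred from the imports above, not confirmed by this diff):

# Hedged usage sketch; paths and module location are assumptions.
from core.data_loader.file_extractor import FileExtractor

# Default: a list of Document objects, one per extracted chunk.
documents = FileExtractor.load_from_file('/tmp/report.pdf')

# return_text=True joins each document's page_content with '\n'.
text = FileExtractor.load_from_file('/tmp/notes.md', return_text=True)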

View File

@@ -1,55 +0,0 @@
import logging
from typing import List, Optional
from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from extensions.ext_storage import storage
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfLoader(BaseLoader):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
upload_file: Optional[UploadFile] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._upload_file = upload_file
def load(self) -> List[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:
if self._upload_file.hash:
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+ self._upload_file.hash + '.0625.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = PyPDFium2Loader(file_path=self._file_path).load()
text_list = []
for document in documents:
text_list.append(document.page_content)
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return documents
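
The deleted loader's core idea is a read-through plaintext cache keyed on the upload's content hash. A minimal sketch of the same pattern, assuming the storage interface above (load/save) and a caller-supplied extract_text callable standing in for PyPDFium2Loader:

def load_pdf_text(storage, tenant_id: str, file_hash: str, extract_text) -> str:
    # Cache key convention mirrors the loader above.
    key = f'upload_files/{tenant_id}/{file_hash}.0625.plaintext'
    try:
        # Cache hit: reuse previously extracted plaintext.
        return storage.load(key).decode('utf-8')
    except FileNotFoundError:
        # Cache miss: extract once, then persist for next time.
        text = extract_text()
        storage.save(key, text.encode('utf-8'))
        return text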

View File

@@ -1,11 +1,12 @@
from typing import Any, Dict, Optional, Sequence, cast
from collections.abc import Sequence
from typing import Any, Optional, cast
from langchain.schema import Document
from sqlalchemy import func
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
@@ -22,10 +23,10 @@ class DatasetDocumentStore:
self._document_id = document_id
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
return cls(**config_dict)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Serialize to dict."""
return {
"dataset_id": self._dataset.id,
@@ -40,7 +41,7 @@ class DatasetDocumentStore:
return self._user_id
@property
def docs(self) -> Dict[str, Document]:
def docs(self) -> dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id
).all()
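
These hunks are part of a repo-wide migration to PEP 585: typing.Dict/List/Tuple give way to the built-in generics available since Python 3.9, and abstract types like Sequence and Iterator are imported from collections.abc instead of typing. A before/after sketch:

# Before: typing aliases, soft-deprecated since Python 3.9.
from typing import Dict, List, Tuple

def top_k(scores: Dict[str, float], k: int) -> List[Tuple[str, float]]:
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:k]

# After: built-in generics per PEP 585; no typing import needed.
def top_k_new(scores: dict[str, float], k: int) -> list[tuple[str, float]]:
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:k]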

View File

@@ -1,14 +1,14 @@
import base64
import logging
from typing import List, Optional, cast
from typing import Optional, cast
import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.entity.embedding import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
@@ -21,7 +21,7 @@ class CacheEmbedding(Embeddings):
self._model_instance = model_instance
self._user = user
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
text_embeddings = []
try:
@@ -52,7 +52,7 @@ class CacheEmbedding(Embeddings):
return text_embeddings
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
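
Beyond the typing changes, CacheEmbedding now subclasses the project's own Embeddings base instead of langchain's. The underlying idea is unchanged: key each text by a content hash and only call the model on cache misses. A minimal sketch with an in-memory store (the hash and store here are simplified stand-ins for the Redis/DB-backed cache above):

import hashlib

_cache: dict[str, list[float]] = {}

def embed_with_cache(text: str, embed_fn) -> list[float]:
    # Key on a stable content hash so identical texts share one embedding.
    key = hashlib.sha256(text.encode('utf-8')).hexdigest()
    if key not in _cache:
        _cache[key] = embed_fn(text)  # the expensive model call
    return _cache[key]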

View File

@@ -42,6 +42,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Advanced Completion Prompt Template Entity.
"""
class RolePrefixEntity(BaseModel):
"""
Role Prefix Entity.
@@ -57,6 +58,7 @@ class PromptTemplateEntity(BaseModel):
"""
Prompt Template Entity.
"""
class PromptType(Enum):
"""
Prompt Type.
@@ -97,6 +99,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.
"""
class RetrieveStrategy(Enum):
"""
Dataset Retrieve Strategy.
@@ -143,6 +146,15 @@ class SensitiveWordAvoidanceEntity(BaseModel):
config: dict[str, Any] = {}
class TextToSpeechEntity(BaseModel):
"""
Text To Speech Entity.
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
class FileUploadEntity(BaseModel):
"""
File Upload Entity.
@@ -159,6 +171,7 @@ class AgentToolEntity(BaseModel):
tool_name: str
tool_parameters: dict[str, Any] = {}
class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
@@ -166,6 +179,7 @@ class AgentPromptEntity(BaseModel):
first_prompt: str
next_iteration: str
class AgentScratchpadUnit(BaseModel):
"""
Agent Scratchpad Unit.
@@ -182,12 +196,14 @@ class AgentScratchpadUnit(BaseModel):
thought: Optional[str] = None
action_str: Optional[str] = None
observation: Optional[str] = None
action: Optional[Action] = None
action: Optional[Action] = None
class AgentEntity(BaseModel):
"""
Agent Entity.
"""
class Strategy(Enum):
"""
Agent Strategy.
@@ -202,6 +218,7 @@ class AgentEntity(BaseModel):
tools: list[AgentToolEntity] = None
max_iteration: int = 5
class AppOrchestrationConfigEntity(BaseModel):
"""
App Orchestration Config Entity.
@@ -219,7 +236,7 @@ class AppOrchestrationConfigEntity(BaseModel):
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
text_to_speech: bool = False
text_to_speech: dict = {}
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
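
With text_to_speech widened from a bool to a dict, the orchestration config can carry voice and language alongside the on/off flag. A plausible shape based on the TextToSpeechEntity fields above (values are illustrative, not taken from the source):

# Illustrative config matching TextToSpeechEntity's fields.
text_to_speech = {
    'enabled': True,
    'voice': 'female-1',   # optional, provider-specific voice id
    'language': 'en-US',   # optional language tag
}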

View File

@@ -41,7 +41,7 @@ class ImagePromptMessageFile(PromptMessageFile):
class LCHumanMessageWithFiles(HumanMessage):
# content: Union[str, List[Union[str, Dict]]]
# content: Union[str, list[Union[str, Dict]]]
content: str
files: list[PromptMessageFile]

View File

@@ -1,8 +1,9 @@
import datetime
import json
import logging
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Optional
from pydantic import BaseModel
@@ -135,7 +136,7 @@ class ProviderConfiguration(BaseModel):
if self.provider.provider_credential_schema else []
)
def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -282,7 +283,7 @@ class ProviderConfiguration(BaseModel):
return None
def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-> Tuple[ProviderModel, dict]:
-> tuple[ProviderModel, dict]:
"""
Validate custom model credentials.
@@ -711,7 +712,7 @@ class ProviderConfigurations(BaseModel):
Model class for provider configuration dict.
"""
tenant_id: str
configurations: Dict[str, ProviderConfiguration] = {}
configurations: dict[str, ProviderConfiguration] = {}
def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
@@ -759,7 +760,7 @@ class ProviderConfigurations(BaseModel):
return all_models
def to_list(self) -> List[ProviderConfiguration]:
def to_list(self) -> list[ProviderConfiguration]:
"""
Convert to list.

View File

@@ -61,7 +61,7 @@ class Extensible:
builtin_file_path = os.path.join(subdir_path, '__builtin__')
if os.path.exists(builtin_file_path):
with open(builtin_file_path, 'r', encoding='utf-8') as f:
with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip())
if (extension_name + '.py') not in file_names:
@@ -93,7 +93,7 @@ class Extensible:
json_path = os.path.join(subdir_path, 'schema.json')
json_data = {}
if os.path.exists(json_path):
with open(json_path, 'r', encoding='utf-8') as f:
with open(json_path, encoding='utf-8') as f:
json_data = json.load(f)
extensions[extension_name] = ModuleExtension(

View File

@@ -1,13 +1,8 @@
import logging
from typing import Optional
from flask import current_app
from core.embedding.cached_embedding import CacheEmbedding
from core.entities.application_entities import InvokeFrom
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, Message, MessageAnnotation
@@ -45,17 +40,6 @@ class AnnotationReplyFeature:
embedding_provider_name = collection_binding_detail.provider_name
embedding_model_name = collection_binding_detail.model_name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=app_record.tenant_id,
provider=embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=embedding_model_name
)
# get embedding model
embeddings = CacheEmbedding(model_instance)
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name,
embedding_model_name,
@@ -71,22 +55,14 @@ class AnnotationReplyFeature:
collection_binding_id=dataset_collection_binding.id
)
vector_index = VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings,
attributes=['doc_id', 'annotation_id', 'app_id']
)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
documents = vector_index.search(
documents = vector.search_by_vector(
query=query,
search_type='similarity_score_threshold',
search_kwargs={
'k': 1,
'score_threshold': score_threshold,
'filter': {
'group_id': [dataset.id]
}
top_k=1,
score_threshold=score_threshold,
filter={
'group_id': [dataset.id]
}
)
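
The rewrite swaps langchain's search_type/search_kwargs indirection for explicit keyword arguments on the new Vector facade. A standalone toy of the threshold-then-top_k semantics implied by the call site (scoring and the group_id filter are simplified away; this is not the real Vector implementation):

def search_by_vector(scored_docs, top_k=1, score_threshold=0.0):
    # Keep results at or above the threshold, then take the best top_k.
    kept = [d for d in scored_docs if d['score'] >= score_threshold]
    kept.sort(key=lambda d: d['score'], reverse=True)
    return kept[:top_k]

hits = search_by_vector(
    [{'id': 'a', 'score': 0.91}, {'id': 'b', 'score': 0.42}],
    top_k=1,
    score_threshold=0.8,
)  # -> [{'id': 'a', 'score': 0.91}]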

View File

@@ -1,8 +1,9 @@
import json
import logging
import uuid
from datetime import datetime
from mimetypes import guess_extension
from typing import List, Optional, Tuple, Union, cast
from typing import Optional, Union, cast
from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager
@@ -20,7 +21,14 @@ from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -50,7 +58,7 @@ class BaseAssistantApplicationRunner(AppRunner):
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[List[PromptMessage]] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
@@ -77,7 +85,9 @@ class BaseAssistantApplicationRunner(AppRunner):
self.message = message
self.user_id = user_id
self.memory = memory
self.history_prompt_messages = prompt_messages
self.history_prompt_messages = self.organize_agent_history(
prompt_messages=prompt_messages or []
)
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
@@ -122,7 +132,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return app_orchestration_config
def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
@@ -134,13 +144,13 @@ class BaseAssistantApplicationRunner(AppRunner):
result += f"result link: {response.message}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += f"image has been created and sent to user already, you should tell user to check it now."
result += "image has been created and sent to user already, you should tell user to check it now."
else:
result += f"tool response: {response.message}."
return result
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
"""
convert tool to prompt message tool
"""
@@ -325,7 +335,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return prompt_tool
def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
"""
Extract tool response binary
"""
@@ -356,7 +366,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return result
def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
"""
Create message file
@@ -404,7 +414,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return result
def create_agent_thought(self, message_id: str, message: str,
tool_name: str, tool_input: str, messages_ids: List[str]
tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
"""
Create agent thought
@@ -449,7 +459,7 @@ class BaseAssistantApplicationRunner(AppRunner):
thought: str,
observation: str,
answer: str,
messages_ids: List[str],
messages_ids: list[str],
llm_usage: LLMUsage = None) -> MessageAgentThought:
"""
Save agent thought
@@ -504,19 +514,8 @@ class BaseAssistantApplicationRunner(AppRunner):
agent_thought.tool_labels_str = json.dumps(labels)
db.session.commit()
def get_history_prompt_messages(self) -> List[PromptMessage]:
"""
Get history prompt messages
"""
if self.history_prompt_messages is None:
self.history_prompt_messages = db.session.query(PromptMessage).filter(
PromptMessage.message_id == self.message.id,
).order_by(PromptMessage.position.asc()).all()
return self.history_prompt_messages
def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
"""
Transform tool message into agent thought
"""
@@ -589,4 +588,60 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
db_variables.updated_at = datetime.utcnow()
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.commit()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize agent history
"""
result = []
# check if there is a system message in the beginning of the conversation
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
result.append(prompt_messages[0])
messages: list[Message] = db.session.query(Message).filter(
Message.conversation_id == self.message.conversation_id,
).order_by(Message.created_at.asc()).all()
for message in messages:
result.append(UserPromptMessage(content=message.query))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(';')
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
tool_inputs = json.loads(agent_thought.tool_input)
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(AssistantPromptMessage.ToolCall(
id=tool_call_id,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
)
))
tool_call_response.append(ToolPromptMessage(
content=agent_thought.observation,
name=tool,
tool_call_id=tool_call_id,
))
result.extend([
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response
])
if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
result.append(AssistantPromptMessage(content=message.answer))
return result
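
organize_agent_history replays past turns in the shape tool-calling models expect: each agent thought becomes an assistant message carrying tool calls, followed by one tool message per invoked tool, with the plain final answer closing the turn. A sketch of the sequence produced for a single past turn (plain dicts stand in for the PromptMessage entities above):

import json
import uuid

call_id = str(uuid.uuid4())
history = [
    {'role': 'user', 'content': 'What is the weather in Paris?'},
    {'role': 'assistant', 'content': 'I should check the weather.',
     'tool_calls': [{'id': call_id, 'type': 'function',
                     'function': {'name': 'weather',
                                  'arguments': json.dumps({'city': 'Paris'})}}]},
    {'role': 'tool', 'name': 'weather', 'tool_call_id': call_id,
     'content': '12°C, clear'},
    {'role': 'assistant', 'content': 'It is 12°C and clear in Paris.'},
]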

View File

@@ -1,6 +1,7 @@
import json
import re
from typing import Dict, Generator, List, Literal, Union
from collections.abc import Generator
from typing import Literal, Union
from core.application_queue_manager import PublishFrom
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
@@ -11,6 +12,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -29,7 +31,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
def run(self, conversation: Conversation,
message: Message,
query: str,
inputs: Dict[str, str],
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
@@ -37,7 +39,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
app_orchestration_config = self.app_orchestration_config
self._repack_app_orchestration_config(app_orchestration_config)
agent_scratchpad: List[AgentScratchpadUnit] = []
agent_scratchpad: list[AgentScratchpadUnit] = []
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
# check model mode
if self.app_orchestration_config.model_config.mode == "completion":
@@ -56,7 +59,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
prompt_messages = self.history_prompt_messages
# convert tools into ModelRuntime Tool format
prompt_messages_tools: List[PromptMessageTool] = []
prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
try:
@@ -83,7 +86,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
}
final_answer = ''
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
@@ -130,61 +133,95 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# recalculate llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
llm_result: LLMResult = model_instance.invoke_llm(
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,
tools=[],
stop=app_orchestration_config.model_config.stop,
stream=False,
stream=True,
user=self.user_id,
callbacks=[],
)
# check llm result
if not llm_result:
if not chunks:
raise ValueError("failed to invoke llm")
# get scratchpad
scratchpad = self._extract_response_scratchpad(llm_result.message.content)
agent_scratchpad.append(scratchpad)
# get llm usage
if llm_result.usage:
increase_usage(llm_usage, llm_result.usage)
usage_dict = {}
react_chunks = self._handle_stream_react(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response='',
thought='',
action_str='',
observation='',
action=None,
)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
for chunk in react_chunks:
if isinstance(chunk, dict):
scratchpad.agent_response += json.dumps(chunk)
try:
if scratchpad.action:
raise Exception("")
scratchpad.action_str = json.dumps(chunk)
scratchpad.action = AgentScratchpadUnit.Action(
action_name=chunk['action'],
action_input=chunk['action_input']
)
except:
scratchpad.thought += json.dumps(chunk)
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=json.dumps(chunk)
),
usage=None
)
)
else:
scratchpad.agent_response += chunk
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=chunk
),
usage=None
)
)
agent_scratchpad.append(scratchpad)
# get llm usage
if 'usage' in usage_dict:
increase_usage(llm_usage, usage_dict['usage'])
else:
usage_dict['usage'] = LLMUsage.empty_usage()
self.save_agent_thought(agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input=scratchpad.action.action_input if scratchpad.action else '',
thought=scratchpad.thought,
observation='',
answer=llm_result.message.content,
answer=scratchpad.agent_response,
messages_ids=[],
llm_usage=llm_result.usage)
llm_usage=usage_dict['usage'])
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# publish agent thought if it's not empty and there is a action
if scratchpad.thought and scratchpad.action:
# check if final answer
if not scratchpad.action.action_name.lower() == "final answer":
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=scratchpad.thought
),
usage=llm_result.usage,
),
system_fingerprint=''
)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = scratchpad.agent_response or ''
@@ -238,7 +275,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
message_file_ids = [message_file.id for message_file, _ in message_files]
except ToolProviderCredentialValidationError as e:
error_response = f"Please check your tool provider credentials"
error_response = "Please check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
@@ -259,7 +296,6 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# save scratchpad
scratchpad.observation = observation
scratchpad.agent_response = llm_result.message.content
# save agent thought
self.save_agent_thought(
@@ -268,7 +304,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
tool_input=tool_call_args,
thought=None,
observation=observation,
answer=llm_result.message.content,
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
)
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
@@ -315,6 +351,97 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
system_fingerprint=''
), PublishFrom.APPLICATION_MANAGER)
def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
-> Generator[Union[str, dict], None, None]:
def parse_json(json_str):
try:
return json.loads(json_str.strip())
except:
return json_str
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
yield parse_json(json_text)
code_block_cache = ''
code_block_delimiter_count = 0
in_code_block = False
json_cache = ''
json_quote_count = 0
in_json = False
got_json = False
for response in llm_response:
response = response.delta.message.content
if not isinstance(response, str):
continue
# stream
index = 0
while index < len(response):
steps = 1
delta = response[index:index+steps]
if delta == '`':
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
yield code_block_cache
code_block_cache = ''
else:
code_block_cache += delta
code_block_delimiter_count = 0
if code_block_delimiter_count == 3:
if in_code_block:
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ''
in_code_block = not in_code_block
code_block_delimiter_count = 0
if not in_code_block:
# handle single json
if delta == '{':
json_quote_count += 1
in_json = True
json_cache += delta
elif delta == '}':
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
if json_quote_count == 0:
in_json = False
got_json = True
index += steps
continue
else:
if in_json:
json_cache += delta
if got_json:
got_json = False
yield parse_json(json_cache)
json_cache = ''
json_quote_count = 0
in_json = False
if not in_code_block and not in_json:
yield delta.replace('`', '')
index += steps
if code_block_cache:
yield code_block_cache
if json_cache:
yield parse_json(json_cache)
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
"""
fill in inputs from external data tools
@@ -326,122 +453,40 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
continue
return instruction
def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
def _init_agent_scratchpad(self,
agent_scratchpad: list[AgentScratchpadUnit],
messages: list[PromptMessage]
) -> list[AgentScratchpadUnit]:
"""
extract response from llm response
init agent scratchpad
"""
def extra_quotes() -> AgentScratchpadUnit:
agent_response = content
# try to extract all quotes
pattern = re.compile(r'```(.*?)```', re.DOTALL)
quotes = pattern.findall(content)
# try to extract action from end to start
for i in range(len(quotes) - 1, 0, -1):
"""
1. use json load to parse action
2. use plain text `Action: xxx` to parse action
"""
try:
action = json.loads(quotes[i].replace('```', ''))
action_name = action.get("action")
action_input = action.get("action_input")
agent_thought = agent_response.replace(quotes[i], '')
if action_name and action_input:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
action_str=quotes[i],
action=AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
current_scratchpad: AgentScratchpadUnit = None
for message in messages:
if isinstance(message, AssistantPromptMessage):
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content,
action_str='',
action=None,
observation=None,
)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments)
)
except:
# try to parse action from plain text
action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
# delete action from agent response
agent_thought = agent_response.replace(quotes[i], '')
# remove extra quotes
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
# remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
if action_name and action_input:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
action_str=quotes[i],
action=AgentScratchpadUnit.Action(
action_name=action_name[0],
action_input=action_input[0],
)
)
def extra_json():
agent_response = content
# try to extract all json
structures, pair_match_stack = [], []
started_at, end_at = 0, 0
for i in range(len(content)):
if content[i] == '{':
pair_match_stack.append(i)
if len(pair_match_stack) == 1:
started_at = i
elif content[i] == '}':
begin = pair_match_stack.pop()
if not pair_match_stack:
end_at = i + 1
structures.append((content[begin:i+1], (started_at, end_at)))
# handle the last character
if pair_match_stack:
end_at = len(content)
structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
for i in range(len(structures), 0, -1):
try:
json_content, (started_at, end_at) = structures[i - 1]
action = json.loads(json_content)
action_name = action.get("action")
action_input = action.get("action_input")
# delete json content from agent response
agent_thought = agent_response[:started_at] + agent_response[end_at:]
# remove extra quotes like ```(json)*\n\n```
agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
# remove Action: xxx from agent thought
agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
if action_name and action_input is not None:
return AgentScratchpadUnit(
agent_response=content,
thought=agent_thought,
action_str=json_content,
action=AgentScratchpadUnit.Action(
action_name=action_name,
action_input=action_input,
)
)
except:
pass
agent_scratchpad = extra_quotes()
if agent_scratchpad:
return agent_scratchpad
agent_scratchpad = extra_json()
if agent_scratchpad:
return agent_scratchpad
return AgentScratchpadUnit(
agent_response=content,
thought=content,
action_str='',
action=None
)
except:
pass
agent_scratchpad.append(current_scratchpad)
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
current_scratchpad.observation = message.content
return agent_scratchpad
def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
agent_prompt_message: AgentPromptEntity,
):
@@ -473,7 +518,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
next_iteration = agent_prompt_message.next_iteration
if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
raise ValueError(f"first_prompt or next_iteration is required in CoT agent mode")
raise ValueError("first_prompt or next_iteration is required in CoT agent mode")
# check instruction, tools, and tool_names slots
if not first_prompt.find("{{instruction}}") >= 0:
@@ -493,7 +538,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
if not next_iteration.find("{{observation}}") >= 0:
raise ValueError("{{observation}} is required in next_iteration")
def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
convert agent scratchpad list to str
"""
@@ -506,13 +551,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
return result
def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
prompt_messages: List[PromptMessage],
tools: List[PromptMessageTool],
agent_scratchpad: List[AgentScratchpadUnit],
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
agent_scratchpad: list[AgentScratchpadUnit],
agent_prompt_message: AgentPromptEntity,
instruction: str,
input: str,
) -> List[PromptMessage]:
) -> list[PromptMessage]:
"""
organize chain of thought prompt messages, a standard prompt message is like:
Respond to the human as helpfully and accurately as possible.
@@ -555,15 +600,22 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# organize prompt messages
if mode == "chat":
# override system message
overrided = False
overridden = False
prompt_messages = prompt_messages.copy()
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_message.content = system_message
overrided = True
overridden = True
break
# convert tool prompt messages to user prompt messages
for idx, prompt_message in enumerate(prompt_messages):
if isinstance(prompt_message, ToolPromptMessage):
prompt_messages[idx] = UserPromptMessage(
content=prompt_message.content
)
if not overrided:
if not overridden:
prompt_messages.insert(0, SystemPromptMessage(
content=system_message,
))
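
The heart of this rewrite is _handle_stream_react, which scans the token stream character by character, buffering ```-fenced blocks and brace-balanced JSON spans so actions surface as dicts while plain thought text streams straight through. A standalone toy of the brace-balancing part (like the real parser, it does not account for braces inside JSON string literals):

import json

def stream_json_objects(chunks):
    # Yield characters as-is; yield brace-balanced spans as parsed dicts.
    buf, depth = '', 0
    for chunk in chunks:
        for ch in chunk:
            if ch == '{':
                depth += 1
            if depth > 0:
                buf += ch
                if ch == '}':
                    depth -= 1
                    if depth == 0:
                        try:
                            yield json.loads(buf)
                        except json.JSONDecodeError:
                            yield buf  # malformed span: pass through as text
                        buf = ''
            else:
                yield ch

parts = list(stream_json_objects(['Think {"action": "we', 'ather"} done']))
# trailing items: ..., {'action': 'weather'}, ' ', 'd', 'o', 'n', 'e'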

View File

@@ -1,6 +1,7 @@
import json
import logging
from typing import Any, Dict, Generator, List, Tuple, Union
from collections.abc import Generator
from typing import Any, Union
from core.application_queue_manager import PublishFrom
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
@@ -44,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
)
# convert tools into ModelRuntime Tool format
prompt_messages_tools: List[PromptMessageTool] = []
prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
try:
@@ -70,13 +71,13 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
# continue to run until there is not any tool call
function_call_state = True
agent_thoughts: List[MessageAgentThought] = []
agent_thoughts: list[MessageAgentThought] = []
llm_usage = {
'usage': None
}
final_answer = ''
def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
@@ -117,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
callbacks=[],
)
tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
# save full response
response = ''
@@ -277,7 +278,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
message_file_ids.append(message_file.id)
except ToolProviderCredentialValidationError as e:
error_response = f"Please check your tool provider credentials"
error_response = "Please check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
@@ -364,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract tool calls from llm result chunk
@@ -381,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
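
extract_tool_calls now yields (tool_call_id, tool_name, parsed_arguments) triples, matching the list[tuple[str, str, dict[str, Any]]] annotation. A sketch of that extraction over an OpenAI-style message dict (the input layout is assumed for illustration):

import json
from typing import Any

def extract_tool_calls(message: dict) -> list[tuple[str, str, dict[str, Any]]]:
    # (id, name, parsed-args) triples from raw tool calls.
    return [
        (call['id'], call['function']['name'],
         json.loads(call['function']['arguments']))
        for call in message.get('tool_calls', [])
    ]

print(extract_tool_calls({'tool_calls': [
    {'id': 'call_1',
     'function': {'name': 'weather', 'arguments': '{"city": "Paris"}'}},
]}))
# [('call_1', 'weather', {'city': 'Paris'})]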

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, cast
from typing import Optional, cast
from langchain.tools import BaseTool
@@ -96,7 +96,7 @@ class DatasetRetrievalFeature:
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler) \
-> Optional[List[BaseTool]]:
-> Optional[list[BaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tenant_id: tenant id

View File

@@ -2,7 +2,7 @@ import concurrent
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Tuple
from typing import Optional
from flask import Flask, current_app
@@ -62,7 +62,7 @@ class ExternalDataFetchFeature:
app_id: str,
external_data_tool: ExternalDataVariableEntity,
inputs: dict,
query: str) -> Tuple[Optional[str], Optional[str]]:
query: str) -> tuple[Optional[str], Optional[str]]:
"""
Query external data tool.
:param flask_app: flask app

View File

@@ -1,5 +1,4 @@
import logging
from typing import Tuple
from core.entities.application_entities import AppOrchestrationConfigEntity
from core.moderation.base import ModerationAction, ModerationException
@@ -13,7 +12,7 @@ class ModerationFeature:
tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity,
inputs: dict,
query: str) -> Tuple[bool, dict, str]:
query: str) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import requests
@@ -15,8 +15,8 @@ class MessageFileParser:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
user: Union[Account, EndUser]) -> List[FileObj]:
def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
user: Union[Account, EndUser]) -> list[FileObj]:
"""
validate and transform files arg
@@ -96,7 +96,7 @@ class MessageFileParser:
# return all file objs
return new_files
def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
"""
transform message files
@@ -110,8 +110,8 @@ class MessageFileParser:
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
def _to_file_objs(self, files: list[Union[dict, MessageFile]],
file_upload_config: dict) -> dict[FileType, list[FileObj]]:
"""
transform files to file objs
@@ -119,7 +119,7 @@ class MessageFileParser:
:param file_upload_config:
:return:
"""
type_file_objs: Dict[FileType, List[FileObj]] = {
type_file_objs: dict[FileType, list[FileObj]] = {
# Currently only support image
FileType.IMAGE: []
}

View File

@@ -104,37 +104,17 @@ class HostingConfiguration:
if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(
quota_limit=hosted_quota_limit,
restrict_models=[
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
]
restrict_models=trial_models
)
quotas.append(trial_quota)
if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
paid_quota = PaidHostingQuota(
restrict_models=[
RestrictModel(model="gpt-4", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
]
restrict_models=paid_models
)
quotas.append(paid_quota)
@@ -258,3 +238,11 @@ class HostingConfiguration:
return HostedModerationConfig(
enabled=False
)
@staticmethod
def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
models_str = app_config.get(env_var)
models_list = models_str.split(",") if models_str else []
return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
model_name.strip()]
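
Trial and paid model lists are no longer hardcoded: parse_restrict_models_from_env derives them from comma-separated env vars such as HOSTED_OPENAI_TRIAL_MODELS and HOSTED_OPENAI_PAID_MODELS. The parsing is just split, strip, and drop empties (the example value below is hypothetical):

# HOSTED_OPENAI_TRIAL_MODELS="gpt-3.5-turbo, gpt-4 ,"
models = [name.strip() for name in "gpt-3.5-turbo, gpt-4 ,".split(",") if name.strip()]
assert models == ['gpt-3.5-turbo', 'gpt-4']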

View File

@@ -1,51 +0,0 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex
from core.index.vector_index.vector_index import VectorIndex
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset
class IndexBuilder:
@classmethod
def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
if indexing_technique == "high_quality":
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
provider=dataset.embedding_model_provider,
model=dataset.embedding_model
)
embeddings = CacheEmbedding(embedding_model)
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
elif indexing_technique == "economy":
return KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
else:
raise ValueError('Unknown indexing technique')
@classmethod
def get_default_high_quality_index(cls, dataset: Dataset):
embeddings = OpenAIEmbeddings(openai_api_key=' ')
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)

View File

@@ -1,305 +0,0 @@
import json
import logging
from abc import abstractmethod
from typing import Any, List, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import VectorStore
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
class BaseVectorIndex(BaseIndex):
def __init__(self, dataset: Dataset, embeddings: Embeddings):
super().__init__(dataset)
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
raise NotImplementedError
@abstractmethod
def get_index_name(self, dataset: Dataset) -> str:
raise NotImplementedError
@abstractmethod
def to_index_struct(self) -> dict:
raise NotImplementedError
@abstractmethod
def _get_vector_store(self) -> VectorStore:
raise NotImplementedError
@abstractmethod
def _get_vector_store_class(self) -> type:
raise NotImplementedError
@abstractmethod
def search_by_full_text_index(
self, query: str,
**kwargs: Any
) -> List[Document]:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
if search_type == 'similarity_score_threshold':
score_threshold = search_kwargs.get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
search_kwargs['score_threshold'] = .0
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
query, **search_kwargs
)
docs = []
for doc, similarity in docs_with_similarity:
doc.metadata['score'] = similarity
docs.append(doc)
return docs
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document], **kwargs):
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if kwargs.get('duplicate_check', False):
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if self.dataset.collection_binding_id:
vector_store.delete_by_group_id(group_id)
else:
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
return False
def recreate_dataset(self, dataset: Dataset):
logging.info(f"Recreating dataset {dataset.id}")
try:
self.delete()
except Exception as e:
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
origin_index_struct = self.dataset.index_struct[:]
self.dataset.index_struct = None
if documents:
try:
self.create(documents)
except Exception as e:
self.dataset.index_struct = origin_index_struct
raise e
dataset.index_struct = json.dumps(self.to_index_struct())
db.session.commit()
self.dataset = dataset
logging.info(f"Dataset {dataset.id} recreate successfully.")
def create_qdrant_dataset(self, dataset: Dataset):
logging.info(f"create_qdrant_dataset {dataset.id}")
try:
self.delete()
except Exception as e:
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.create(documents)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def update_qdrant_dataset(self, dataset: Dataset):
logging.info(f"update_qdrant_dataset {dataset.id}")
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).first()
if segment:
try:
exist = self.text_exists(segment.index_node_id)
if exist:
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.add_texts(documents)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"delete original collection: {dataset.id}")
self.delete()
dataset.collection_binding_id = dataset_collection_binding.id
db.session.add(dataset)
db.session.commit()
logging.info(f"Dataset {dataset.id} recreate successfully.")

Some files were not shown because too many files have changed in this diff.