Mirror of https://github.com/langgenius/dify.git (synced 2026-01-09 07:44:12 +00:00)

Compare commits: 39 commits
| Author | SHA1 | Date |
|---|---|---|
| | 8b49e0ee2a | |
| | e031ec9359 | |
| | 1bd1cd6938 | |
| | 81c5a21b8d | |
| | 61e4bbabaf | |
| | 4cf475680d | |
| | ca4aa340f6 | |
| | 767d8a4b05 | |
| | 0b8dcaba8f | |
| | af6a318aae | |
| | c6e2900be7 | |
| | 963d9b6032 | |
| | b2ee738bb1 | |
| | c8ca3ff404 | |
| | 5d8fa2c7af | |
| | 58df5e5376 | |
| | 348ad1a624 | |
| | 73e17d5aa8 | |
| | 300d9892a5 | |
| | e47b5b43b8 | |
| | 21c9d9e200 | |
| | 4f6916c4d8 | |
| | 8633957726 | |
| | 0850c953b3 | |
| | 23e95fd7ab | |
| | e1045f01c6 | |
| | e6d22fc3a0 | |
| | 9232244920 | |
| | 476eb90a90 | |
| | 063191889d | |
| | 589099a005 | |
| | a0ec7de058 | |
| | 14a19a3da9 | |
| | 1b04382a9b | |
| | 71e5828d41 | |
| | 65a02f7d32 | |
| | acf9174bef | |
| | 243ca5b1e2 | |
| | f6059c377c | |

.github/pull_request_template.md (vendored, new file, 30 lines)

@@ -0,0 +1,30 @@
# Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (issue)

## Type of Change

Please delete options that are not relevant.

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)

# How Has This Been Tested?

Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration

- [ ] TODO

# Suggested Checklist:

- [ ] I have performed a self-review of my own code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] My changes generate no new warnings
- [ ] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
- [ ] `optional` I have made corresponding changes to the documentation
- [ ] `optional` I have added tests that prove my fix is effective or that my feature works
- [ ] `optional` New and existing unit tests pass locally with my changes

.github/workflows/tool-test-sdks.yaml (vendored, new file, 34 lines)

@@ -0,0 +1,34 @@
name: Run Unit Test For SDKs

on:
  pull_request:
    branches:
      - main
jobs:
  build:
    name: unit test for Node.js SDK
    runs-on: ubuntu-latest

    strategy:
      matrix:
        node-version: [16, 18, 20]

    defaults:
      run:
        working-directory: sdks/nodejs-client

    steps:
      - uses: actions/checkout@v4

      - name: Use Node.js ${{ matrix.node-version }}
        uses: actions/setup-node@v4
        with:
          node-version: ${{ matrix.node-version }}
          cache: ''
          cache-dependency-path: 'yarn.lock'

      - name: Install Dependencies
        run: yarn install

      - name: Test
        run: yarn test

@@ -81,11 +81,17 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
# Model Configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64

# Mail configuration, support: resend
MAIL_TYPE=
# Mail configuration, support: resend, smtp
MAIL_TYPE=resend
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
SMTP_SERVER=smtp.gmail.com
SMTP_PORT=587
SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=false

# Sentry configuration
SENTRY_DSN=
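For context, the new `SMTP_*` variables follow the conventional `smtplib` pattern. The snippet below is only a sketch of how such settings are typically consumed; it is not taken from the Dify codebase, and the helper name `send_plain_mail` is hypothetical. The variable names match the `.env` example above.

```python
import os
import smtplib
from email.message import EmailMessage

def send_plain_mail(to_addr: str, subject: str, body: str) -> None:
    """Minimal sketch: send a plain-text mail using the SMTP_* settings above."""
    msg = EmailMessage()
    msg["From"] = os.environ.get("MAIL_DEFAULT_SEND_FROM", "no-reply <no-reply@dify.ai>")
    msg["To"] = to_addr
    msg["Subject"] = subject
    msg.set_content(body)

    server = os.environ.get("SMTP_SERVER", "smtp.gmail.com")
    port = int(os.environ.get("SMTP_PORT", "587"))
    with smtplib.SMTP(server, port) as smtp:
        # Upgrade to TLS only when SMTP_USE_TLS is enabled, mirroring the boolean flag above.
        if os.environ.get("SMTP_USE_TLS", "false").lower() == "true":
            smtp.starttls()
        smtp.login(os.environ["SMTP_USERNAME"], os.environ["SMTP_PASSWORD"])
        smtp.send_message(msg)
```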

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import os

from werkzeug.exceptions import Unauthorized
@@ -39,10 +38,11 @@ from extensions import (
from extensions.ext_database import db
from extensions.ext_login import login_manager
from libs.passport import PassportService

# DO NOT REMOVE BELOW
from services.account_service import AccountService

# DO NOT REMOVE BELOW
from events import event_handlers
from models import account, dataset, model, source, task, tool, tools, web
# DO NOT REMOVE ABOVE

@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import os

import dotenv
@@ -87,7 +86,7 @@ class Config:
        # ------------------------
        # General Configurations.
        # ------------------------
        self.CURRENT_VERSION = "0.5.4"
        self.CURRENT_VERSION = "0.5.6"
        self.COMMIT_SHA = get_env('COMMIT_SHA')
        self.EDITION = "SELF_HOSTED"
        self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -209,6 +208,12 @@ class Config:
        self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
        self.RESEND_API_KEY = get_env('RESEND_API_KEY')
        self.RESEND_API_URL = get_env('RESEND_API_URL')
        # SMTP settings
        self.SMTP_SERVER = get_env('SMTP_SERVER')
        self.SMTP_PORT = get_env('SMTP_PORT')
        self.SMTP_USERNAME = get_env('SMTP_USERNAME')
        self.SMTP_PASSWORD = get_env('SMTP_PASSWORD')
        self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')

        # ------------------------
        # Workpace Configurations.
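The new SMTP values are read through the same `get_env` / `get_bool_env` helpers as the existing settings. As a rough illustration only (the real helpers live elsewhere in the codebase and are an assumption here), such helpers are typically thin wrappers over `os.environ` with defaults:

```python
import os

# Hypothetical defaults table; the project keeps its own defaults.
DEFAULTS = {
    'SMTP_USE_TLS': 'False',
    'DEPLOY_ENV': 'PRODUCTION',
}

def get_env(key: str) -> str:
    """Return the environment variable, falling back to a default when one exists."""
    return os.environ.get(key, DEFAULTS.get(key, ''))

def get_bool_env(key: str) -> bool:
    """Interpret common truthy spellings ('true', '1', 'yes') as True."""
    return get_env(key).lower() in ('true', '1', 'yes')

# The Config class above would then do, for example:
# self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
```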

@@ -1,9 +1,8 @@

import json

from models.model import AppModelConfig

languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT']
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA']

language_timezone_mapping = {
    'en-US': 'America/New_York',
@@ -16,8 +15,10 @@ language_timezone_mapping = {
    'ko-KR': 'Asia/Seoul',
    'ru-RU': 'Europe/Moscow',
    'it-IT': 'Europe/Rome',
    'uk-UA': 'Europe/Kyiv',
}


def supported_language(lang):
    if lang in languages:
        return lang
@@ -26,6 +27,7 @@ def supported_language(lang):
                 .format(lang=lang))
        raise ValueError(error)


user_input_form_template = {
    "en-US": [
        {
@@ -67,6 +69,16 @@ user_input_form_template = {
            }
        }
    ],
    "ua-UK": [
        {
            "paragraph": {
                "label": "Запит",
                "variable": "default_input",
                "required": False,
                "default": ""
            }
        }
    ],
}
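The new 'uk-UA' entries plug into the existing helpers in this file. A small usage sketch, exercising only the names shown in the hunk above (`languages`, `supported_language`, `language_timezone_mapping`); the 'UTC' fallback is an assumption:

```python
# Resolve the interface language and a default timezone for a new user.
lang = supported_language('uk-UA')                # returns 'uk-UA' because it is in `languages`
tz = language_timezone_mapping.get(lang, 'UTC')   # 'Europe/Kyiv'
print(lang, tz)

try:
    supported_language('xx-XX')                   # not in `languages`
except ValueError as err:
    print(err)                                    # error raised by supported_language
```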
demo_model_templates = {

@@ -145,7 +157,7 @@ demo_model_templates = {
                    'Italian',
                ]
            }
        },{
        }, {
            "paragraph": {
                "label": "Query",
                "variable": "query",
@@ -272,7 +284,7 @@ demo_model_templates = {
                    "意大利语",
                ]
            }
        },{
        }, {
            "paragraph": {
                "label": "文本内容",
                "variable": "query",
@@ -323,5 +335,130 @@ demo_model_templates = {
|
||||
)
|
||||
}
|
||||
],
|
||||
'uk-UA': [{
|
||||
"name": "Помічник перекладу",
|
||||
"icon": "",
|
||||
"icon_background": "",
|
||||
"description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.",
|
||||
"mode": "completion",
|
||||
"model_config": AppModelConfig(
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo-instruct",
|
||||
configs={
|
||||
"prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n",
|
||||
"prompt_variables": [
|
||||
{
|
||||
"key": "target_language",
|
||||
"name": "Цільова мова",
|
||||
"description": "Мова, на яку ви хочете перекласти.",
|
||||
"type": "select",
|
||||
"default": "Ukrainian",
|
||||
"options": [
|
||||
"Chinese",
|
||||
"English",
|
||||
"Japanese",
|
||||
"French",
|
||||
"Russian",
|
||||
"German",
|
||||
"Spanish",
|
||||
"Korean",
|
||||
"Italian",
|
||||
],
|
||||
},
|
||||
],
|
||||
"completion_params": {
|
||||
"max_token": 1000,
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
},
|
||||
opening_statement="",
|
||||
suggested_questions=None,
|
||||
pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
}),
|
||||
user_input_form=json.dumps([
|
||||
{
|
||||
"select": {
|
||||
"label": "Цільова мова",
|
||||
"variable": "target_language",
|
||||
"description": "Мова, на яку ви хочете перекласти.",
|
||||
"default": "Chinese",
|
||||
"required": True,
|
||||
'options': [
|
||||
'Chinese',
|
||||
'English',
|
||||
'Japanese',
|
||||
'French',
|
||||
'Russian',
|
||||
'German',
|
||||
'Spanish',
|
||||
'Korean',
|
||||
'Italian',
|
||||
]
|
||||
}
|
||||
}, {
|
||||
"paragraph": {
|
||||
"label": "Запит",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
])
|
||||
)
|
||||
},
|
||||
{
|
||||
"name": "AI інтерв’юер фронтенду",
|
||||
"icon": "",
|
||||
"icon_background": "",
|
||||
"description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.",
|
||||
"mode": "chat",
|
||||
"model_config": AppModelConfig(
|
||||
provider="openai",
|
||||
model_id="gpt-3.5-turbo",
|
||||
configs={
|
||||
"introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
|
||||
"prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
|
||||
"prompt_variables": [],
|
||||
"completion_params": {
|
||||
"max_token": 300,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.9,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
},
|
||||
opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ",
|
||||
suggested_questions=None,
|
||||
pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.9,
|
||||
"presence_penalty": 0.1,
|
||||
"frequency_penalty": 0.1,
|
||||
},
|
||||
}),
|
||||
user_input_form=None
|
||||
),
|
||||
}
|
||||
],
|
||||
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-
import json
import logging
from datetime import datetime
@@ -133,8 +132,8 @@ class AppListApi(Resource):

            if not model_instance:
                raise ProviderNotInitializeError(
                    f"No Default System Reasoning Model available. Please configure "
                    f"in the Settings -> Model Provider.")
                    "No Default System Reasoning Model available. Please configure "
                    "in the Settings -> Model Provider.")
            else:
                model_config_dict["model"]["provider"] = model_instance.provider
                model_config_dict["model"]["name"] = model_instance.model
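Several hunks in this commit set make the same change as the one above: the `f` prefix is dropped from string literals that contain no placeholders. A short generic illustration of the pattern (the lint rule is an assumption about the project's tooling; the check is commonly reported as F541 by Ruff/flake8):

```python
# Before: the f-prefix does nothing because there is no {placeholder},
# so linters flag it and the prefix is removed.
message = (
    f"No Default System Reasoning Model available. Please configure "
    f"in the Settings -> Model Provider."
)

# After: plain string literals; implicit concatenation still joins the two parts.
message = (
    "No Default System Reasoning Model available. Please configure "
    "in the Settings -> Model Provider."
)
```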
@@ -1,8 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
@@ -46,7 +45,9 @@ class ChatMessageAudioApi(Resource):
|
||||
try:
|
||||
response = AudioService.transcript_asr(
|
||||
tenant_id=app_model.tenant_id,
|
||||
file=file
|
||||
file=file,
|
||||
end_user=None,
|
||||
promot=app_model.app_model_config.pre_prompt
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -72,7 +73,7 @@ class ChatMessageAudioApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error, {str(e)}.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@@ -83,10 +84,12 @@ class ChatMessageTextApi(Resource):
|
||||
def post(self, app_id):
|
||||
app_id = str(app_id)
|
||||
app_model = _get_app(app_id, None)
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=request.form['text'],
|
||||
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
)
|
||||
|
||||
@@ -113,9 +116,50 @@ class ChatMessageTextApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error, {str(e)}.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextModesApi(Resource):
|
||||
def get(self, app_id: str):
|
||||
app_model = _get_app(str(app_id))
|
||||
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('language', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
tenant_id=app_model.tenant_id,
|
||||
language=args['language'],
|
||||
)
|
||||
|
||||
return response
|
||||
except services.errors.audio.ProviderNotSupportTextToSpeechLanageServiceError:
|
||||
raise AppUnavailableError("Text to audio voices language parameter loss.")
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
except AudioTooLargeServiceError as e:
|
||||
raise AudioTooLargeError(str(e))
|
||||
except UnsupportedAudioTypeServiceError:
|
||||
raise UnsupportedAudioTypeError()
|
||||
except ProviderNotSupportSpeechToTextServiceError:
|
||||
raise ProviderNotSupportSpeechToTextError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception(f"internal server error, {str(e)}.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
|
||||
api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
|
||||
api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
|
||||
|
||||
@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from typing import Generator, Union
from collections.abc import Generator
from typing import Union

import flask_login
from flask import Response, stream_with_context
@@ -169,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            for chunk in response:
                yield chunk
            yield from response

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')
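Many of the streaming endpoints in this set replace the manual re-yield loop with generator delegation, as in the hunk above. A standalone sketch of the same refactor; the function names here are illustrative, not from the codebase:

```python
from collections.abc import Generator

def tokens() -> Generator[str, None, None]:
    yield from ("data: hello\n\n", "data: world\n\n")

# Before: re-yield each chunk by hand.
def relay_loop(source: Generator[str, None, None]) -> Generator[str, None, None]:
    for chunk in source:
        yield chunk

# After: `yield from` delegates to the source generator directly;
# it is shorter and also forwards .send()/.throw() to the delegate.
def relay(source: Generator[str, None, None]) -> Generator[str, None, None]:
    yield from source

print(list(relay(tokens())) == list(relay_loop(tokens())))  # True
```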
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
@@ -246,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import flask_login
|
||||
from flask import current_app, request
|
||||
from flask_restful import Resource, reqparse
|
||||
@@ -8,7 +7,7 @@ from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from libs.helper import email
|
||||
from libs.password import valid_password
|
||||
from services.account_service import AccountService
|
||||
from services.account_service import AccountService, TenantService
|
||||
|
||||
|
||||
class LoginApi(Resource):
|
||||
@@ -30,6 +29,8 @@ class LoginApi(Resource):
|
||||
except services.errors.account.AccountLoginError:
|
||||
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
# todo: return the user info
|
||||
|
||||
@@ -10,7 +10,7 @@ from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from models.account import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
|
||||
from .. import api
|
||||
|
||||
@@ -76,6 +76,8 @@ class OAuthCallback(Resource):
|
||||
account.initialized_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
token = AccountService.get_account_jwt_token(account)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import flask_restful
|
||||
from flask import current_app, request
|
||||
from flask_login import current_user
|
||||
@@ -288,8 +287,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
elif args['info_list']['data_source_type'] == 'notion_import':
|
||||
@@ -304,8 +303,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
else:
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
@@ -71,7 +69,7 @@ class DocumentResource(Resource):
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
@@ -296,8 +294,8 @@ class DatasetInitApi(Resource):
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -372,8 +370,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -442,8 +440,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
elif dataset.data_source_type == 'notion_import':
|
||||
@@ -456,8 +454,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
None, 'English', dataset_id)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
else:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
@@ -143,8 +142,8 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -234,8 +233,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
try:
|
||||
@@ -286,8 +285,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
|
||||
@@ -76,8 +76,8 @@ class HitTestingApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
@@ -86,6 +85,7 @@ class ChatTextApi(InstalledAppResource):
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=request.form['text'],
|
||||
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
)
|
||||
return {'data': response.data.decode('latin1')}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Generator, Union
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
@@ -164,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_login import current_user
|
||||
@@ -123,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
@@ -78,7 +77,7 @@ class ExploreAppMetaApi(InstalledAppResource):
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ f"/console/api/workspaces/current/tool-provider/builtin/")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import and_
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from functools import wraps
|
||||
|
||||
from flask import current_app, request
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, abort, fields, marshal_with, reqparse
|
||||
@@ -12,6 +11,7 @@ from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
@@ -72,6 +72,13 @@ class MemberInviteEmailApi(Resource):
|
||||
'email': invitee_email,
|
||||
'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
|
||||
})
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append({
|
||||
'status': 'success',
|
||||
'email': invitee_email,
|
||||
'url': f'{console_web_url}/signin'
|
||||
})
|
||||
break
|
||||
except Exception as e:
|
||||
invitation_results.append({
|
||||
'status': 'failed',
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class WorkspaceWebappLogoApi(Resource):
|
||||
webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None
|
||||
|
||||
if not webapp_logo_file_id:
|
||||
raise NotFound(f'webapp logo is not found')
|
||||
raise NotFound('webapp logo is not found')
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_public_image_preview(
|
||||
|
||||
@@ -32,7 +32,7 @@ class ToolFilePreviewApi(Resource):
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise NotFound(f'file is not found')
|
||||
raise NotFound('file is not found')
|
||||
|
||||
generator, mimetype = result
|
||||
except Exception:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
@@ -78,7 +77,7 @@ class AppMetaApi(AppApiResource):
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ f"/console/api/workspaces/current/tool-provider/builtin/")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
|
||||
@@ -86,6 +86,7 @@ class TextApi(AppApiResource):
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=args['text'],
|
||||
end_user=args['user'],
|
||||
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=args['streaming']
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import reqparse
|
||||
@@ -182,8 +183,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal, reqparse
|
||||
from sqlalchemy import desc
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -46,8 +46,8 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# validate args
|
||||
@@ -90,8 +90,8 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
@@ -182,8 +182,8 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
"No Embedding Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
@@ -77,7 +76,7 @@ class AppMeta(WebApiResource):
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
url_prefix = (current_app.config.get("CONSOLE_API_URL")
|
||||
+ f"/console/api/workspaces/current/tool-provider/builtin/")
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
@@ -69,17 +68,23 @@ class AudioApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error: {str(e)}")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class TextApi(WebApiResource):
|
||||
def post(self, app_model: App, end_user):
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.text_to_speech_dict['enabled']:
|
||||
raise AppUnavailableError()
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_tts(
|
||||
tenant_id=app_model.tenant_id,
|
||||
text=request.form['text'],
|
||||
end_user=end_user.external_user_id,
|
||||
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
|
||||
streaming=False
|
||||
)
|
||||
|
||||
@@ -106,7 +111,7 @@ class TextApi(WebApiResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
logging.exception(f"internal server error: {str(e)}")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import reqparse
|
||||
@@ -154,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from typing import Generator, Union
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from flask_restful import fields, marshal_with, reqparse
|
||||
@@ -160,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
yield chunk
|
||||
yield from response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200,
|
||||
mimetype='text/event-stream')
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import fields, marshal_with
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
@@ -17,7 +17,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Before invoke callback
|
||||
@@ -38,7 +38,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None):
|
||||
"""
|
||||
On new chunk callback
|
||||
@@ -58,7 +58,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
After invoke callback
|
||||
@@ -80,7 +80,7 @@ class AgentLLMCallback(Callback):
|
||||
|
||||
def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Invoke error callback
|
||||
|
||||
@@ -1,4 +1,4 @@
from typing import List, cast
from typing import cast

from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
@@ -8,7 +8,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large

class CalcTokenMixin:

    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
        """
        Got the rest tokens available for the model after excluding messages tokens and completion max tokens
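The typing churn throughout these hunks follows one pattern: `typing.List` / `Tuple` / `Dict` become the built-in generics `list` / `tuple` / `dict`, and `Generator` / `Sequence` move to `collections.abc` (PEP 585, available since Python 3.9). A minimal before/after sketch with illustrative names:

```python
# Before (pre-3.9 style): generic aliases imported from typing.
from typing import Dict, Generator, List, Optional, Tuple

def summarize_legacy(steps: List[Tuple[str, str]], limits: Dict[str, int],
                     stop: Optional[List[str]] = None) -> Generator[str, None, None]:
    for name, observation in steps:
        yield f"{name}: {observation}"

# After (3.9+ style): built-in generics plus collections.abc, as in the diff above.
from collections.abc import Generator
from typing import Optional

def summarize(steps: list[tuple[str, str]], limits: dict[str, int],
              stop: Optional[list[str]] = None) -> Generator[str, None, None]:
    for name, observation in steps:
        yield f"{name}: {observation}"
```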
@@ -1,4 +1,5 @@
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
@@ -42,7 +43,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -85,7 +86,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -146,7 +147,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -158,7 +159,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
@@ -51,7 +52,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
model_config: ModelConfigEntity,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
@@ -125,7 +126,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -207,7 +208,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
def return_stopped_response(
|
||||
self,
|
||||
early_stopping_method: str,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
**kwargs: Any,
|
||||
) -> AgentFinish:
|
||||
try:
|
||||
@@ -215,7 +216,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
|
||||
def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(
|
||||
self.model_config,
|
||||
@@ -264,7 +265,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
return new_messages
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
self, messages: list[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
@@ -275,7 +276,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
|
||||
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
|
||||
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
|
||||
@@ -68,7 +69,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -125,8 +126,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
@@ -153,7 +154,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
@@ -180,7 +181,7 @@ Thought: {agent_scratchpad}
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
self, intermediate_steps: list[tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
@@ -213,8 +214,8 @@ Thought: {agent_scratchpad}
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import re
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
|
||||
@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
@@ -127,7 +128,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_config:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
@@ -154,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
self, messages: list[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
@@ -173,8 +174,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
) -> BasePromptTemplate:
|
||||
tool_strings = []
|
||||
for tool in tools:
|
||||
@@ -200,7 +201,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
@@ -227,7 +228,7 @@ Thought: {agent_scratchpad}
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
self, intermediate_steps: list[tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
@@ -260,8 +261,8 @@ Thought: {agent_scratchpad}
|
||||
suffix: str = SUFFIX,
|
||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
input_variables: Optional[list[str]] = None,
|
||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
||||
agent_llm_callback: Optional[AgentLLMCallback] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
from typing import Generator, List, Optional, Tuple, Union, cast
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
|
||||
from core.entities.application_entities import (
|
||||
@@ -84,7 +85,7 @@ class AppRunner:
|
||||
return rest_tokens
|
||||
|
||||
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
|
||||
prompt_messages: List[PromptMessage]):
|
||||
prompt_messages: list[PromptMessage]):
|
||||
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
@@ -126,7 +127,7 @@ class AppRunner:
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
-> Tuple[List[PromptMessage], Optional[List[str]]]:
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
:param context:
|
||||
@@ -295,7 +296,7 @@ class AppRunner:
|
||||
tenant_id: str,
|
||||
app_orchestration_config_entity: AppOrchestrationConfigEntity,
|
||||
inputs: dict,
|
||||
query: str) -> Tuple[bool, dict, str]:
|
||||
query: str) -> tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
|
||||
@@ -38,7 +38,7 @@ class AssistantApplicationRunner(AppRunner):
"""
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
if not app_record:
raise ValueError(f"App not found")
raise ValueError("App not found")

app_orchestration_config = application_generate_entity.app_orchestration_config_entity

@@ -35,7 +35,7 @@ class BasicApplicationRunner(AppRunner):
"""
app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
if not app_record:
raise ValueError(f"App not found")
raise ValueError("App not found")

app_orchestration_config = application_generate_entity.app_orchestration_config_entity

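The two `App not found` hunks above are the f-string cleanup: an f-string with no `{placeholders}` is just a plain string literal, so the `f` prefix is dropped (the pyflakes/ruff F541 check). A tiny illustrative sketch, not code from the repository:

```python
def load_app(app_record: object):
    if not app_record:
        # Before: raise ValueError(f"App not found") -- the f-prefix does nothing without {placeholders}
        raise ValueError("App not found")
    return app_record

try:
    load_app(None)
except ValueError as e:
    print(e)  # App not found
```
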
@@ -1,7 +1,8 @@
import json
import logging
import time
from typing import Generator, Optional, Union, cast
from collections.abc import Generator
from typing import Optional, Union, cast

from pydantic import BaseModel

@@ -118,7 +119,7 @@ class GenerateTaskPipeline:
}

self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
else:
@@ -201,7 +202,7 @@ class GenerateTaskPipeline:
data = self._error_to_stream_response_data(self._handle_error(event))
yield self._yield_response(data)
break
elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
if isinstance(event, QueueMessageEndEvent):
self._task_state.llm_result = event.llm_result
else:
@@ -353,7 +354,7 @@ class GenerateTaskPipeline:

yield self._yield_response(response)

elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
chunk = event.chunk
delta_text = chunk.delta.message.content
if delta_text is None:

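The `isinstance(event, QueueStopEvent | QueueMessageEndEvent)` form relies on PEP 604: on Python 3.10+ `isinstance` accepts an `X | Y` union object as well as the classic tuple of types. A sketch with stand-in classes (the Queue* names below are placeholders, not the real event classes):

```python
class QueueStopEvent: ...
class QueueMessageEndEvent: ...

event = QueueStopEvent()

# Equivalent checks; the union form needs Python 3.10+
print(isinstance(event, (QueueStopEvent, QueueMessageEndEvent)))  # True
print(isinstance(event, QueueStopEvent | QueueMessageEndEvent))   # True
```
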
@@ -1,7 +1,7 @@
import logging
import threading
import time
from typing import Any, Dict, Optional
from typing import Any, Optional

from flask import Flask, current_app
from pydantic import BaseModel
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)

class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
config: dict[str, Any]

class OutputModerationHandler(BaseModel):

@@ -2,7 +2,8 @@ import json
import logging
import threading
import uuid
from typing import Any, Generator, Optional, Tuple, Union, cast
from collections.abc import Generator
from typing import Any, Optional, Union, cast

from flask import Flask, current_app
from pydantic import ValidationError
@@ -27,6 +28,7 @@ from core.entities.application_entities import (
ModelConfigEntity,
PromptTemplateEntity,
SensitiveWordAvoidanceEntity,
TextToSpeechEntity,
)
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
@@ -571,7 +573,11 @@ class ApplicationManager:
text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
if text_to_speech_dict:
if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
properties['text_to_speech'] = True
properties['text_to_speech'] = TextToSpeechEntity(
enabled=text_to_speech_dict.get('enabled'),
voice=text_to_speech_dict.get('voice'),
language=text_to_speech_dict.get('language'),
)

# sensitive word avoidance
sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
@@ -585,7 +591,7 @@ class ApplicationManager:
return AppOrchestrationConfigEntity(**properties)

def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
-> Tuple[Conversation, Message]:
-> tuple[Conversation, Message]:
"""
Initialize generate records
:param application_generate_entity: application generate entity

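Beyond the typing cleanup, the ApplicationManager hunk above changes behaviour: the text-to-speech flag in the app config is promoted from a bare `True` to a structured entity carrying `voice` and `language`. A rough sketch of that mapping, assuming a pydantic model shaped like the `TextToSpeechEntity` added later in this diff; the sample config values are made up for illustration:

```python
from typing import Optional
from pydantic import BaseModel

class TextToSpeechEntity(BaseModel):
    enabled: bool
    voice: Optional[str] = None
    language: Optional[str] = None

# sample config values, not taken from the repository
text_to_speech_dict = {'enabled': True, 'voice': 'alloy', 'language': 'en-US'}

properties = {}
if text_to_speech_dict and text_to_speech_dict.get('enabled'):
    properties['text_to_speech'] = TextToSpeechEntity(
        enabled=text_to_speech_dict.get('enabled'),
        voice=text_to_speech_dict.get('voice'),
        language=text_to_speech_dict.get('language'),
    )

print(properties['text_to_speech'])
```
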
@@ -1,7 +1,8 @@
import queue
import time
from collections.abc import Generator
from enum import Enum
from typing import Any, Generator
from typing import Any

from sqlalchemy.orm import DeclarativeMeta

@@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Optional, Union, cast

from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
@@ -37,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._message_agent_thought = None

@property
def agent_loops(self) -> List[AgentLoop]:
def agent_loops(self) -> list[AgentLoop]:
return self._agent_loops

def clear_agent_loops(self) -> None:
@@ -95,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):

def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any
) -> Any:
pass

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
pass

@@ -120,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):

def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
@@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
def on_tool_start(
self,
tool_name: str,
tool_inputs: Dict[str, Any],
tool_inputs: dict[str, Any],
) -> None:
"""Do nothing."""
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
def on_tool_end(
self,
tool_name: str,
tool_inputs: Dict[str, Any],
tool_inputs: dict[str, Any],
tool_outputs: str,
) -> None:
"""If not the final action, print out observation."""

@@ -1,4 +1,3 @@
from typing import List

from langchain.schema import Document

@@ -40,22 +39,26 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_query)
db.session.commit()

def on_tool_end(self, documents: List[Document]) -> None:
def on_tool_end(self, documents: list[Document]) -> None:
"""Handle tool end."""
for document in documents:
doc_id = document.metadata['doc_id']
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata['doc_id']
)

# if 'dataset_id' in document.metadata:
if 'dataset_id' in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])

# add hit count to document segment
db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == doc_id
).update(
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)

db.session.commit()

def return_retriever_resource_info(self, resource: List):
def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info."""
if resource and len(resource) > 0:
for item in resource:

@@ -1,6 +1,6 @@
import os
import sys
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
@@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):

def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any
) -> Any:
print_text("\n[on_chat_model_start]\n", color='blue')
@@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
print_text(str(sub_message) + "\n", color='blue')

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print_text("\n[on_llm_start]\n", color='blue')
@@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')

def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
chain_type = serialized['id'][-1]
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')

@@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):

def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
@@ -16,12 +16,12 @@ class LLMChain(LCLLMChain):
model_config: ModelConfigEntity
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
parameters: Dict[str, Any] = {}
parameters: dict[str, Any] = {}
agent_llm_callback: Optional[AgentLLMCallback] = None

def generate(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""

@@ -1,6 +1,6 @@
import tempfile
from pathlib import Path
from typing import List, Optional, Union
from typing import Optional, Union

import requests
from flask import current_app
@@ -28,7 +28,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM

class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
@@ -37,7 +37,7 @@ class FileExtractor:
return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = requests.get(url, headers={
"User-Agent": USER_AGENT
})
@@ -53,7 +53,7 @@ class FileExtractor:
@classmethod
def load_from_file(cls, file_path: str, return_text: bool = False,
upload_file: Optional[UploadFile] = None,
is_automatic: bool = False) -> Union[List[Document], str]:
is_automatic: bool = False) -> Union[list[Document], str]:
input_file = Path(file_path)
delimiter = '\n'
file_extension = input_file.suffix.lower()

@@ -1,6 +1,6 @@
import csv
import logging
from typing import Dict, List, Optional
from typing import Optional

from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings
@@ -14,7 +14,7 @@ class CSVLoader(LCCSVLoader):
self,
file_path: str,
source_column: Optional[str] = None,
csv_args: Optional[Dict] = None,
csv_args: Optional[dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
@@ -24,7 +24,7 @@ class CSVLoader(LCCSVLoader):
self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding

def load(self) -> List[Document]:
def load(self) -> list[Document]:
"""Load data into document objects."""
try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:

@@ -1,5 +1,4 @@
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@@ -23,7 +22,7 @@ class ExcelLoader(BaseLoader):
"""Initialize with file path."""
self._file_path = file_path

def load(self) -> List[Document]:
def load(self) -> list[Document]:
data = []
keys = []
wb = load_workbook(filename=self._file_path, read_only=True)

@@ -1,5 +1,4 @@
import logging
from typing import List

from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
@@ -23,7 +22,7 @@ class HTMLLoader(BaseLoader):
"""Initialize with file path."""
self._file_path = file_path

def load(self) -> List[Document]:
def load(self) -> list[Document]:
return [Document(page_content=self._load_as_text())]

def _load_as_text(self) -> str:

@@ -1,6 +1,6 @@
import logging
import re
from typing import List, Optional, Tuple, cast
from typing import Optional, cast

from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
@@ -42,7 +42,7 @@ class MarkdownLoader(BaseLoader):
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding

def load(self) -> List[Document]:
def load(self) -> list[Document]:
tups = self.parse_tups(self._file_path)
documents = []
for header, value in tups:
@@ -54,13 +54,13 @@ class MarkdownLoader(BaseLoader):

return documents

def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]:
"""Convert a markdown file to a dictionary.

The keys are the headers and the values are the text under each header.

"""
markdown_tups: List[Tuple[Optional[str], str]] = []
markdown_tups: list[tuple[Optional[str], str]] = []
lines = markdown_text.split("\n")

current_header = None
@@ -103,11 +103,11 @@ class MarkdownLoader(BaseLoader):
content = re.sub(pattern, r"\1", content)
return content

def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]:
"""Parse file into tuples."""
content = ""
try:
with open(filepath, "r", encoding=self._encoding) as f:
with open(filepath, encoding=self._encoding) as f:
content = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:

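The `open(filepath, "r", ...)` to `open(filepath, ...)` changes in these loader hunks rely on the fact that `'r'` is already the default mode of `open()`, so passing it explicitly is redundant (the pyupgrade/ruff UP015 check). A small self-contained example:

```python
from pathlib import Path

path = Path("example.txt")            # throwaway file created just for this demo
path.write_text("hello", encoding="utf-8")

# Before: with open(path, "r", encoding="utf-8") as f:
with open(path, encoding="utf-8") as f:   # 'r' is the default mode
    print(f.read())
```
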
@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional

import requests
from flask import current_app
@@ -67,7 +67,7 @@ class NotionLoader(BaseLoader):
document_model=document_model
)

def load(self) -> List[Document]:
def load(self) -> list[Document]:
self.update_last_edited_time(
self._document_model
)
@@ -78,7 +78,7 @@ class NotionLoader(BaseLoader):

def _load_data_as_documents(
self, notion_obj_id: str, notion_page_type: str
) -> List[Document]:
) -> list[Document]:
docs = []
if notion_page_type == 'database':
# get all the pages in the database
@@ -94,8 +94,8 @@ class NotionLoader(BaseLoader):
return docs

def _get_notion_database_data(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> List[Document]:
self, database_id: str, query_dict: dict[str, Any] = {}
) -> list[Document]:
"""Get all the pages from a Notion database."""
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
@@ -149,12 +149,12 @@ class NotionLoader(BaseLoader):

return database_content_list

def _get_notion_block_data(self, page_id: str) -> List[str]:
def _get_notion_block_data(self, page_id: str) -> list[str]:
result_lines_arr = []
cur_block_id = page_id
while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
query_dict: dict[str, Any] = {}

res = requests.request(
"GET",
@@ -216,7 +216,7 @@ class NotionLoader(BaseLoader):
cur_block_id = block_id
while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
query_dict: dict[str, Any] = {}

res = requests.request(
"GET",
@@ -280,7 +280,7 @@ class NotionLoader(BaseLoader):
cur_block_id = block_id
while not done:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
query_dict: dict[str, Any] = {}

res = requests.request(
"GET",
@@ -346,7 +346,7 @@ class NotionLoader(BaseLoader):
else:
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)

query_dict: Dict[str, Any] = {}
query_dict: dict[str, Any] = {}

res = requests.request(
"GET",

@@ -1,5 +1,5 @@
import logging
from typing import List, Optional
from typing import Optional

from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
@@ -28,7 +28,7 @@ class PdfLoader(BaseLoader):
self._file_path = file_path
self._upload_file = upload_file

def load(self) -> List[Document]:
def load(self) -> list[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:

@@ -1,6 +1,5 @@
import base64
import logging
from typing import List

from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
@@ -24,7 +23,7 @@ class UnstructuredEmailLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url

def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.email import partition_email
elements = partition_email(filename=self._file_path, api_url=self._api_url)

@@ -1,5 +1,4 @@
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@@ -34,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url

def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.md import partition_md

elements = partition_md(filename=self._file_path, api_url=self._api_url)

@@ -1,5 +1,4 @@
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@@ -24,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url

def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.msg import partition_msg

elements = partition_msg(filename=self._file_path, api_url=self._api_url)

@@ -1,5 +1,4 @@
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@@ -23,7 +22,7 @@ class UnstructuredPPTLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url

def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.ppt import partition_ppt

elements = partition_ppt(filename=self._file_path, api_url=self._api_url)

@@ -1,5 +1,4 @@
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@@ -22,7 +21,7 @@ class UnstructuredPPTXLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url

def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.pptx import partition_pptx

elements = partition_pptx(filename=self._file_path, api_url=self._api_url)

@@ -1,5 +1,4 @@
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@@ -24,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url

def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.text import partition_text

elements = partition_text(filename=self._file_path, api_url=self._api_url)

@@ -1,5 +1,4 @@
import logging
from typing import List

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
@@ -24,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader):
self._file_path = file_path
self._api_url = api_url

def load(self) -> List[Document]:
def load(self) -> list[Document]:
from unstructured.partition.xml import partition_xml

elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)

@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Sequence, cast
from collections.abc import Sequence
from typing import Any, Optional, cast

from langchain.schema import Document
from sqlalchemy import func
@@ -22,10 +23,10 @@ class DatasetDocumentStore:
self._document_id = document_id

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
return cls(**config_dict)

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Serialize to dict."""
return {
"dataset_id": self._dataset.id,
@@ -40,7 +41,7 @@ class DatasetDocumentStore:
return self._user_id

@property
def docs(self) -> Dict[str, Document]:
def docs(self) -> dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id
).all()

@@ -1,6 +1,6 @@
import base64
import logging
from typing import List, Optional, cast
from typing import Optional, cast

import numpy as np
from langchain.embeddings.base import Embeddings
@@ -21,7 +21,7 @@ class CacheEmbedding(Embeddings):
self._model_instance = model_instance
self._user = user

def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
text_embeddings = []
try:
@@ -52,7 +52,7 @@ class CacheEmbedding(Embeddings):

return text_embeddings

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)

@@ -42,6 +42,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Advanced Completion Prompt Template Entity.
"""

class RolePrefixEntity(BaseModel):
"""
Role Prefix Entity.
@@ -57,6 +58,7 @@ class PromptTemplateEntity(BaseModel):
"""
Prompt Template Entity.
"""

class PromptType(Enum):
"""
Prompt Type.
@@ -97,6 +99,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.
"""

class RetrieveStrategy(Enum):
"""
Dataset Retrieve Strategy.
@@ -143,6 +146,15 @@ class SensitiveWordAvoidanceEntity(BaseModel):
config: dict[str, Any] = {}

class TextToSpeechEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None

class FileUploadEntity(BaseModel):
"""
File Upload Entity.
@@ -159,6 +171,7 @@ class AgentToolEntity(BaseModel):
tool_name: str
tool_parameters: dict[str, Any] = {}

class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
@@ -166,6 +179,7 @@ class AgentPromptEntity(BaseModel):
first_prompt: str
next_iteration: str

class AgentScratchpadUnit(BaseModel):
"""
Agent First Prompt Entity.
@@ -182,12 +196,14 @@ class AgentScratchpadUnit(BaseModel):
thought: Optional[str] = None
action_str: Optional[str] = None
observation: Optional[str] = None
action: Optional[Action] = None
action: Optional[Action] = None

class AgentEntity(BaseModel):
"""
Agent Entity.
"""

class Strategy(Enum):
"""
Agent Strategy.
@@ -202,6 +218,7 @@ class AgentEntity(BaseModel):
tools: list[AgentToolEntity] = None
max_iteration: int = 5

class AppOrchestrationConfigEntity(BaseModel):
"""
App Orchestration Config Entity.
@@ -219,7 +236,7 @@ class AppOrchestrationConfigEntity(BaseModel):
show_retrieve_source: bool = False
more_like_this: bool = False
speech_to_text: bool = False
text_to_speech: bool = False
text_to_speech: dict = {}
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None

@@ -41,7 +41,7 @@ class ImagePromptMessageFile(PromptMessageFile):

class LCHumanMessageWithFiles(HumanMessage):
# content: Union[str, List[Union[str, Dict]]]
# content: Union[str, list[Union[str, Dict]]]
content: str
files: list[PromptMessageFile]

@@ -1,8 +1,9 @@
import datetime
import json
import logging
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Optional

from pydantic import BaseModel

@@ -135,7 +136,7 @@ class ProviderConfiguration(BaseModel):
if self.provider.provider_credential_schema else []
)

def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -282,7 +283,7 @@ class ProviderConfiguration(BaseModel):
return None

def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-> Tuple[ProviderModel, dict]:
-> tuple[ProviderModel, dict]:
"""
Validate custom model credentials.

@@ -711,7 +712,7 @@ class ProviderConfigurations(BaseModel):
Model class for provider configuration dict.
"""
tenant_id: str
configurations: Dict[str, ProviderConfiguration] = {}
configurations: dict[str, ProviderConfiguration] = {}

def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id)
@@ -759,7 +760,7 @@ class ProviderConfigurations(BaseModel):

return all_models

def to_list(self) -> List[ProviderConfiguration]:
def to_list(self) -> list[ProviderConfiguration]:
"""
Convert to list.

@@ -61,7 +61,7 @@ class Extensible:

builtin_file_path = os.path.join(subdir_path, '__builtin__')
if os.path.exists(builtin_file_path):
with open(builtin_file_path, 'r', encoding='utf-8') as f:
with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip())

if (extension_name + '.py') not in file_names:
@@ -93,7 +93,7 @@ class Extensible:
json_path = os.path.join(subdir_path, 'schema.json')
json_data = {}
if os.path.exists(json_path):
with open(json_path, 'r', encoding='utf-8') as f:
with open(json_path, encoding='utf-8') as f:
json_data = json.load(f)

extensions[extension_name] = ModuleExtension(

@@ -2,7 +2,7 @@ import json
import logging
from datetime import datetime
from mimetypes import guess_extension
from typing import List, Optional, Tuple, Union, cast
from typing import Optional, Union, cast

from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager
@@ -50,7 +50,7 @@ class BaseAssistantApplicationRunner(AppRunner):
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[List[PromptMessage]] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
@@ -122,7 +122,7 @@ class BaseAssistantApplicationRunner(AppRunner):

return app_orchestration_config

def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
@@ -134,13 +134,13 @@ class BaseAssistantApplicationRunner(AppRunner):
result += f"result link: {response.message}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += f"image has been created and sent to user already, you should tell user to check it now."
result += "image has been created and sent to user already, you should tell user to check it now."
else:
result += f"tool response: {response.message}."

return result

def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
"""
convert tool to prompt message tool
"""
@@ -325,7 +325,7 @@ class BaseAssistantApplicationRunner(AppRunner):

return prompt_tool

def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
"""
Extract tool response binary
"""
@@ -356,7 +356,7 @@ class BaseAssistantApplicationRunner(AppRunner):

return result

def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
"""
Create message file

@@ -404,7 +404,7 @@ class BaseAssistantApplicationRunner(AppRunner):
return result

def create_agent_thought(self, message_id: str, message: str,
tool_name: str, tool_input: str, messages_ids: List[str]
tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
"""
Create agent thought
@@ -449,7 +449,7 @@ class BaseAssistantApplicationRunner(AppRunner):
thought: str,
observation: str,
answer: str,
messages_ids: List[str],
messages_ids: list[str],
llm_usage: LLMUsage = None) -> MessageAgentThought:
"""
Save agent thought
@@ -505,7 +505,7 @@ class BaseAssistantApplicationRunner(AppRunner):

db.session.commit()

def get_history_prompt_messages(self) -> List[PromptMessage]:
def get_history_prompt_messages(self) -> list[PromptMessage]:
"""
Get history prompt messages
"""
@@ -516,7 +516,7 @@ class BaseAssistantApplicationRunner(AppRunner):

return self.history_prompt_messages

def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
"""
Transform tool message into agent thought
"""

@@ -1,6 +1,7 @@
import json
import re
from typing import Dict, Generator, List, Literal, Union
from collections.abc import Generator
from typing import Literal, Union

from core.application_queue_manager import PublishFrom
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
@@ -29,7 +30,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
def run(self, conversation: Conversation,
message: Message,
query: str,
inputs: Dict[str, str],
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
@@ -37,7 +38,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
app_orchestration_config = self.app_orchestration_config
self._repack_app_orchestration_config(app_orchestration_config)

agent_scratchpad: List[AgentScratchpadUnit] = []
agent_scratchpad: list[AgentScratchpadUnit] = []

# check model mode
if self.app_orchestration_config.model_config.mode == "completion":
@@ -56,7 +57,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
prompt_messages = self.history_prompt_messages

# convert tools into ModelRuntime Tool format
prompt_messages_tools: List[PromptMessageTool] = []
prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
try:
@@ -83,7 +84,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
}
final_answer = ''

def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
@@ -238,7 +239,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):

message_file_ids = [message_file.id for message_file, _ in message_files]
except ToolProviderCredentialValidationError as e:
error_response = f"Please check your tool provider credentials"
error_response = "Please check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
@@ -473,7 +474,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
next_iteration = agent_prompt_message.next_iteration

if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
raise ValueError(f"first_prompt or next_iteration is required in CoT agent mode")
raise ValueError("first_prompt or next_iteration is required in CoT agent mode")

# check instruction, tools, and tool_names slots
if not first_prompt.find("{{instruction}}") >= 0:
@@ -493,7 +494,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
if not next_iteration.find("{{observation}}") >= 0:
raise ValueError("{{observation}} is required in next_iteration")

def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
convert agent scratchpad list to str
"""
@@ -506,13 +507,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
return result

def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
prompt_messages: List[PromptMessage],
tools: List[PromptMessageTool],
agent_scratchpad: List[AgentScratchpadUnit],
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
agent_scratchpad: list[AgentScratchpadUnit],
agent_prompt_message: AgentPromptEntity,
instruction: str,
input: str,
) -> List[PromptMessage]:
) -> list[PromptMessage]:
"""
organize chain of thought prompt messages, a standard prompt message is like:
Respond to the human as helpfully and accurately as possible.

@@ -1,6 +1,7 @@
import json
import logging
from typing import Any, Dict, Generator, List, Tuple, Union
from collections.abc import Generator
from typing import Any, Union

from core.application_queue_manager import PublishFrom
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
@@ -44,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
)

# convert tools into ModelRuntime Tool format
prompt_messages_tools: List[PromptMessageTool] = []
prompt_messages_tools: list[PromptMessageTool] = []
tool_instances = {}
for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
try:
@@ -70,13 +71,13 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):

# continue to run until there is not any tool call
function_call_state = True
agent_thoughts: List[MessageAgentThought] = []
agent_thoughts: list[MessageAgentThought] = []
llm_usage = {
'usage': None
}
final_answer = ''

def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
@@ -117,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
callbacks=[],
)

tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []
tool_calls: list[tuple[str, str, dict[str, Any]]] = []

# save full response
response = ''
@@ -277,7 +278,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
message_file_ids.append(message_file.id)

except ToolProviderCredentialValidationError as e:
error_response = f"Please check your tool provider credentials"
error_response = "Please check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
@@ -364,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
return True
return False

def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract tool calls from llm result chunk

@@ -381,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):

return tool_calls

def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result

@@ -1,4 +1,4 @@
from typing import List, Optional, cast
from typing import Optional, cast

from langchain.tools import BaseTool

@@ -96,7 +96,7 @@ class DatasetRetrievalFeature:
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler) \
-> Optional[List[BaseTool]]:
-> Optional[list[BaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tenant_id: tenant id

@@ -2,7 +2,7 @@ import concurrent
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Tuple
from typing import Optional

from flask import Flask, current_app

@@ -62,7 +62,7 @@ class ExternalDataFetchFeature:
app_id: str,
external_data_tool: ExternalDataVariableEntity,
inputs: dict,
query: str) -> Tuple[Optional[str], Optional[str]]:
query: str) -> tuple[Optional[str], Optional[str]]:
"""
Query external data tool.
:param flask_app: flask app

@@ -1,5 +1,4 @@
import logging
from typing import Tuple

from core.entities.application_entities import AppOrchestrationConfigEntity
from core.moderation.base import ModerationAction, ModerationException
@@ -13,7 +12,7 @@ class ModerationFeature:
tenant_id: str,
app_orchestration_config_entity: AppOrchestrationConfigEntity,
inputs: dict,
query: str) -> Tuple[bool, dict, str]:
query: str) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id

@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Optional, Union

import requests

@@ -15,8 +15,8 @@ class MessageFileParser:
self.tenant_id = tenant_id
self.app_id = app_id

def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
user: Union[Account, EndUser]) -> List[FileObj]:
def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
user: Union[Account, EndUser]) -> list[FileObj]:
"""
validate and transform files arg

@@ -96,7 +96,7 @@ class MessageFileParser:
# return all file objs
return new_files

def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
"""
transform message files

@@ -110,8 +110,8 @@ class MessageFileParser:
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]

def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
def _to_file_objs(self, files: list[Union[dict, MessageFile]],
file_upload_config: dict) -> dict[FileType, list[FileObj]]:
"""
transform files to file objs

@@ -119,7 +119,7 @@ class MessageFileParser:
:param file_upload_config:
:return:
"""
type_file_objs: Dict[FileType, List[FileObj]] = {
type_file_objs: dict[FileType, list[FileObj]] = {
# Currently only support image
FileType.IMAGE: []
}

@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, List
from typing import Any

from langchain.schema import BaseRetriever, Document

@@ -53,7 +53,7 @@ class BaseIndex(ABC):
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
) -> list[Document]:
raise NotImplementedError

def delete(self) -> None:

@@ -1,5 +1,4 @@
import re
from typing import Set

import jieba
from jieba.analyse import default_tfidf
@@ -12,7 +11,7 @@ class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS

def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,
@@ -21,7 +20,7 @@ class JiebaKeywordTableHandler:

return set(self._expand_tokens_with_subtokens(keywords))

def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:

Some files were not shown because too many files have changed in this diff.