Compare commits

..

55 Commits
0.3.4 ... 0.3.7

Author SHA1 Message Date
John Wang
57de19a5ca feat: bump version to 0.3.7 (#540) 2023-07-10 15:23:38 +08:00
zxhlyh
7c00a0b6a3 fix voice input in safari (#537) 2023-07-10 10:16:38 +08:00
Jyong
a93506df18 Fix/dataset clean task (#534) 2023-07-08 17:29:56 +08:00
zxhlyh
a03a92e9db Feat/chat support voice input (#532) 2023-07-07 17:50:42 +08:00
John Wang
feebb5dd1f feat: dataset list add order by created at (#531) 2023-07-07 11:51:48 +08:00
John Wang
6eee7cb42c feat: fix azure embedding Too many inputs problem (#530) 2023-07-07 11:17:36 +08:00
Joel
11baff6740 feat: text generation application support run batch (#529) 2023-07-07 10:35:05 +08:00
zxhlyh
cde1797cc0 feat: max token add tip (#525) 2023-07-06 15:57:04 +08:00
KVOJJJin
d143284d99 Fix: stop embedding status display (#523) 2023-07-06 10:51:30 +08:00
zxhlyh
2b94545190 fix check version api (#520) 2023-07-05 11:11:38 +08:00
John Wang
ed6648a41e feat: dataset list add order by created at (#487) 2023-07-05 11:00:21 +08:00
Joel
5e2c3eeac3 fix: chat app added new var old conversation not work (#511) 2023-07-04 14:33:41 +08:00
Joel
b23d8a912b fix: add missing like i18n (#512) 2023-07-04 14:21:51 +08:00
Joel
4f13f8fd0a fix: change langenius text to dify (#498) 2023-07-02 14:01:11 +08:00
Joel
561c9cabd5 fix: input text repeat (#492) 2023-06-29 17:27:48 +08:00
zxhlyh
39ea967b30 refact common layout (#490) 2023-06-29 15:30:12 +08:00
John Wang
da04ff040b fix: remove document from dataset error when vector index npe (#489) 2023-06-29 13:09:22 +08:00
John Wang
b9b0866a46 fix: generate summary error when tokens=4097 (#488) 2023-06-29 12:54:50 +08:00
Joel
c6ab7eebd9 fix: delete operation style error (#485) 2023-06-29 09:24:31 +08:00
Joel
db4e6d81c5 fix: choose dataset not selected after one page (#481) 2023-06-29 09:22:42 +08:00
John Wang
df68a7c82b feat: Optimize the quality of the title generate (#484) 2023-06-28 19:59:20 +08:00
Joel
838825d747 feat: optimize conversation operation (#479) 2023-06-28 17:53:23 +08:00
crazywoola
a87f6f2837 fix: modal disappear (#478) 2023-06-28 16:44:17 +08:00
John Wang
9d98669e7d fix: dataset destination error (#477) 2023-06-28 15:51:07 +08:00
John Wang
408fbb0c70 fix: title, summary, suggested questions generate (#476) 2023-06-28 15:43:33 +08:00
crazywoola
998f819b04 use sub to operate all (#475) 2023-06-28 14:58:40 +08:00
John Wang
6194b82752 feat: bump to 0.3.6 (#474) 2023-06-28 14:23:20 +08:00
Jyong
334f46d0b6 Fix/json format (#466) 2023-06-28 13:58:50 +08:00
Jyong
2eea114ac0 fix special code (#473) 2023-06-28 13:58:36 +08:00
crazywoola
97e9ebd29a Feature/add is deleted to conversations (#470) 2023-06-28 13:31:51 +08:00
Joel
ec261aea54 feat: conversation app support pin and delete conversation (#467) 2023-06-28 11:16:54 +08:00
Joel
accc5faae3 fix: delete dataset not trigger show start new conversation message (#471) 2023-06-28 10:39:40 +08:00
Joel
0462f09ecc fix: app nav call detail match explore app detail page (#469) 2023-06-27 18:40:24 +08:00
zxhlyh
1226d73159 Feat/refact header (#468) 2023-06-27 18:02:01 +08:00
Jyong
c67ecff3fe Fix/json format (#465) 2023-06-27 17:15:03 +08:00
John Wang
d5b42c09ee fix: template parse error when history include {{any}} (#463) 2023-06-27 16:35:50 +08:00
John Wang
835bf9fd8d fix: template parse error when pre prompt include {{}} (#462) 2023-06-27 15:51:55 +08:00
John Wang
c720f831af feat: optimize template parse (#460) 2023-06-27 15:30:38 +08:00
John Wang
df5763be37 feat: optimize openai error raise (#459) 2023-06-27 12:34:47 +08:00
zxhlyh
80eebc2414 feat: upgrade nextjs version (#457) 2023-06-27 12:12:41 +08:00
zxhlyh
17d196126c Feat/add icons (#450) 2023-06-26 15:36:52 +08:00
Joel
addf150a9e fix: hove x scroll shake (#449) 2023-06-26 13:35:12 +08:00
John Wang
cad1532f7c feat: optimize index_struct copy (#442) 2023-06-25 17:52:22 +08:00
John Wang
951afcaaed feat: optimize weaviate error msg (#441) 2023-06-25 17:05:56 +08:00
John Wang
3241e4015b feat: upgrade langchain (#430)
Co-authored-by: jyong <718720800@qq.com>
2023-06-25 16:49:14 +08:00
Bin
1dee5de9b4 bugfix: conversation parameters (#438) 2023-06-25 16:14:42 +08:00
John Wang
742bad93b5 feat: bump version to 0.3.5 (#433) 2023-06-21 16:18:41 +08:00
Joel
bb3cc6bba6 fix: file size limit to 15M (#431) 2023-06-21 16:08:57 +08:00
lisaifei@cvte.com
23ef2262bd fix: filter empty value in xlsx to improve vector similarity hit (#422) 2023-06-21 11:25:52 +08:00
Joel
d637a147ee feat: support batch upload files (#419) 2023-06-21 09:44:01 +08:00
crazywoola
8a4d19d9ba fix: actions 2023-06-21 09:10:07 +08:00
Joel
bea382f0dc fix: dataset can only choose first page data (#425)
Support infinite scroll loader data.
2023-06-20 18:08:28 +08:00
John Wang
8b39e48957 fix REDIS_USERNAME format (#414) 2023-06-19 22:14:47 +08:00
crazywoola
5b4538f021 feat: add more labels 2023-06-19 22:09:02 +08:00
Jyong
36dc05c4da fix chinese encoding (#411) 2023-06-19 18:41:17 +08:00
300 changed files with 7916 additions and 4165 deletions

View File

@@ -27,3 +27,4 @@ jobs:
stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
stale-issue-label: 'no-issue-activity'
stale-pr-label: 'no-pr-activity'
any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement'

2
.gitignore vendored
View File

@@ -147,3 +147,5 @@ docker/volumes/weaviate/*
sdks/python-client/build
sdks/python-client/dist
sdks/python-client/dify_client.egg-info
.vscode/

View File

@@ -22,7 +22,7 @@ CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
# redis configuration
REDIS_HOST=localhost
REDIS_PORT=6379
REDIS_USERNAME: ''
REDIS_USERNAME=
REDIS_PASSWORD=difyai123456
REDIS_DB=0
@@ -90,4 +90,4 @@ SQLALCHEMY_ECHO=false
NOTION_INTEGRATION_TYPE=public
NOTION_CLIENT_SECRET=you-client-secret
NOTION_CLIENT_ID=you-client-id
NOTION_INTERNAL_SECRET=you-internal-secret
NOTION_INTERNAL_SECRET=you-internal-secret

View File

@@ -14,7 +14,7 @@ from flask import Flask, request, Response, session
import flask_login
from flask_cors import CORS
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage
from extensions.ext_database import db
from extensions.ext_login import login_manager
@@ -79,7 +79,6 @@ def initialize_extensions(app):
ext_database.init_app(app)
ext_migrate.init(app, db)
ext_redis.init_app(app)
ext_vector_store.init_app(app)
ext_storage.init_app(app)
ext_celery.init_app(app)
ext_session.init_app(app)

View File

@@ -1,15 +1,19 @@
import datetime
import logging
import random
import string
import click
from flask import current_app
from werkzeug.exceptions import NotFound
from core.index.index import IndexBuilder
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant
from models.dataset import Dataset
from models.model import Account
import secrets
import base64
@@ -159,8 +163,39 @@ def generate_upper_string():
return result
@click.command('recreate-all-dataset-indexes', help='Recreate all dataset indexes.')
def recreate_all_dataset_indexes():
click.echo(click.style('Start recreate all dataset indexes.', fg='green'))
recreate_count = 0
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality')\
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for dataset in datasets:
try:
click.echo('Recreating dataset index: {}'.format(dataset.id))
index = IndexBuilder.get_index(dataset, 'high_quality')
if index and index._is_origin():
index.recreate_dataset(dataset)
recreate_count += 1
else:
click.echo('passed.')
except Exception as e:
click.echo(click.style('Recreate dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(generate_invitation_codes)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(recreate_all_dataset_indexes)

View File

@@ -79,7 +79,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.4"
self.CURRENT_VERSION = "0.3.7"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -187,11 +187,13 @@ class Config:
# For temp use only
# set default LLM provider, default is 'openai', support `azure_openai`
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
class CloudEditionConfig(Config):

View File

@@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import setup, version, apikey, admin
# Import app controllers
from .app import app, site, completion, model_config, statistic, conversation, message, generator
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
# Import auth controllers
from .auth import login, oauth, data_source_oauth
@@ -21,4 +21,4 @@ from .datasets import datasets, datasets_document, datasets_segments, file, hit_
from .workspace import workspace, members, providers, account
# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio

View File

@@ -22,6 +22,7 @@ model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
@@ -144,6 +145,7 @@ class AppListApi(Resource):
opening_statement=model_configuration['opening_statement'],
suggested_questions=json.dumps(model_configuration['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),
@@ -434,6 +436,7 @@ class AppCopy(Resource):
opening_statement=app_config.opening_statement,
suggested_questions=app_config.suggested_questions,
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
speech_to_text=app_config.speech_to_text,
more_like_this=app_config.more_like_this,
model=app_config.model,
user_input_form=app_config.user_input_form,

View File

@@ -0,0 +1,69 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
from flask_login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.app.error import AppUnavailableError, \
ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from flask_restful import Resource
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
class ChatMessageAudioApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id):
app_id = str(app_id)
app_model = _get_app(app_id, 'chat')
file = request.files['file']
try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')

View File

@@ -209,6 +209,26 @@ class CompletionConversationDetailApi(Resource):
conversation_id = str(conversation_id)
return _get_conversation(app_id, conversation_id, 'completion')
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id, conversation_id):
app_id = str(app_id)
conversation_id = str(conversation_id)
app = _get_app(app_id, 'chat')
conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
if not conversation:
raise NotFound("Conversation Not Exists.")
conversation.is_deleted = True
db.session.commit()
return {'result': 'success'}, 204
class ChatConversationApi(Resource):
@@ -356,6 +376,27 @@ class ChatConversationDetailApi(Resource):
conversation_id = str(conversation_id)
return _get_conversation(app_id, conversation_id, 'chat')
@setup_required
@login_required
@account_initialization_required
def delete(self, app_id, conversation_id):
app_id = str(app_id)
conversation_id = str(conversation_id)
# get app info
app = _get_app(app_id, 'chat')
conversation = db.session.query(Conversation) \
.filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
if not conversation:
raise NotFound("Conversation Not Exists.")
conversation.is_deleted = True
db.session.commit()
return {'result': 'success'}, 204

View File

@@ -49,3 +49,27 @@ class AppMoreLikeThisDisabledError(BaseHTTPException):
error_code = 'app_more_like_this_disabled'
description = "The 'More like this' feature is disabled. Please refresh your page."
code = 403
class NoAudioUploadedError(BaseHTTPException):
error_code = 'no_audio_uploaded'
description = "Please upload your audio."
code = 400
class AudioTooLargeError(BaseHTTPException):
error_code = 'audio_too_large'
description = "Audio size exceeded. {message}"
code = 413
class UnsupportedAudioTypeError(BaseHTTPException):
error_code = 'unsupported_audio_type'
description = "Audio type not allowed."
code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400

View File

@@ -41,6 +41,7 @@ class ModelConfigResource(Resource):
opening_statement=model_configuration['opening_statement'],
suggested_questions=json.dumps(model_configuration['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
speech_to_text=json.dumps(model_configuration['speech_to_text']),
more_like_this=json.dumps(model_configuration['more_like_this']),
model=json.dumps(model_configuration['model']),
user_input_form=json.dumps(model_configuration['user_input_form']),

View File

@@ -10,11 +10,10 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.data_source.notion import NotionPageReader
from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.oauth_data_source import NotionOAuth
from models.dataset import Document
from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService
@@ -232,15 +231,17 @@ class DataSourceNotionApi(Resource):
).first()
if not data_source_binding:
raise NotFound('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
if page_type == 'page':
page_content = reader.read_page(page_id)
elif page_type == 'database':
page_content = reader.query_database_data(page_id)
else:
page_content = ""
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type
)
text_docs = loader.load()
return {
'content': page_content
'content': "\n".join([doc.page_content for doc in text_docs])
}, 200
@setup_required

View File

@@ -17,9 +17,7 @@ from controllers.console.datasets.error import NoFileUploadedError, TooManyFiles
UnsupportedFileTypeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.index.readers.html_parser import HTMLParser
from core.index.readers.pdf_parser import PDFParser
from core.index.readers.xlsx_parser import XLSXParser
from core.data_loader.file_extractor import FileExtractor
from extensions.ext_storage import storage
from libs.helper import TimestampField
from extensions.ext_database import db
@@ -123,31 +121,7 @@ class FilePreviewApi(Resource):
if extension not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, filepath)
if extension == 'pdf':
parser = PDFParser({'upload_file': upload_file})
text = parser.parse_file(Path(filepath))
elif extension in ['html', 'htm']:
# Use BeautifulSoup to extract text
parser = HTMLParser()
text = parser.parse_file(Path(filepath))
elif extension == 'xlsx':
parser = XLSXParser()
text = parser.parse_file(filepath)
else:
# ['txt', 'markdown', 'md']
with open(filepath, "rb") as fp:
data = fp.read()
encoding = chardet.detect(data)['encoding']
if encoding:
text = data.decode(encoding=encoding).strip() if data else ''
else:
text = data.decode(encoding='utf-8').strip() if data else ''
text = FileExtractor.load(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return {'content': text}

View File

@@ -0,0 +1,66 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
from werkzeug.exceptions import InternalServerError
import services
from controllers.console import api
from controllers.console.app.error import AppUnavailableError, ProviderNotInitializeError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError, \
NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.explore.wraps import InstalledAppResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
from models.model import AppModelConfig
class ChatAudioApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file']
try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')

View File

@@ -21,6 +21,7 @@ class AppParameterApi(InstalledAppResource):
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
}
@@ -35,6 +36,7 @@ class AppParameterApi(InstalledAppResource):
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -32,8 +32,13 @@ class VersionApi(Resource):
'current_version': args.get('current_version')
})
except Exception as error:
logging.exception("Check update error.")
raise InternalServerError()
logging.warning("Check update version error: {}.".format(str(error)))
return {
'version': args.get('current_version'),
'release_date': '',
'release_notes': '',
'can_auto_update': False
}
content = json.loads(response.content)
return {

View File

@@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
api = ExternalApi(bp)
from .app import completion, app, conversation, message
from .app import completion, app, conversation, message, audio
from .dataset import document

View File

@@ -22,6 +22,7 @@ class AppParameterApi(AppApiResource):
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
}
@@ -35,6 +36,7 @@ class AppParameterApi(AppApiResource):
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -0,0 +1,61 @@
import logging
from flask import request
from werkzeug.exceptions import InternalServerError
import services
from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError, ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
ProviderNotSupportSpeechToTextError
from controllers.service_api.wraps import AppApiResource
from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from models.model import App, AppModelConfig
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
class AudioApi(AppApiResource):
def post(self, app_model: App, end_user):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file']
try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(AudioApi, '/audio-to-text')

View File

@@ -48,6 +48,26 @@ class ConversationApi(AppApiResource):
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource):
@marshal_with(conversation_fields)
def delete(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
ConversationService.delete(app_model, conversation_id, end_user)
return {"result": "success"}, 204
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
class ConversationRenameApi(AppApiResource):
@@ -74,3 +94,4 @@ class ConversationRenameApi(AppApiResource):
api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='conversation_name')
api.add_resource(ConversationApi, '/conversations')
api.add_resource(ConversationApi, '/conversations/<uuid:c_id>', endpoint='conversation')

View File

@@ -51,3 +51,27 @@ class CompletionRequestError(BaseHTTPException):
description = "Completion request failed."
code = 400
class NoAudioUploadedError(BaseHTTPException):
error_code = 'no_audio_uploaded'
description = "Please upload your audio."
code = 400
class AudioTooLargeError(BaseHTTPException):
error_code = 'audio_too_large'
description = "Audio size exceeded. {message}"
code = 413
class UnsupportedAudioTypeError(BaseHTTPException):
error_code = 'unsupported_audio_type'
description = "Audio type not allowed."
code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400

View File

@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
api = ExternalApi(bp)
from . import completion, app, conversation, message, site, saved_message
from . import completion, app, conversation, message, site, saved_message, audio

View File

@@ -21,6 +21,7 @@ class AppParameterApi(WebApiResource):
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'suggested_questions_after_answer': fields.Raw,
'speech_to_text': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
}
@@ -34,6 +35,7 @@ class AppParameterApi(WebApiResource):
'opening_statement': app_model_config.opening_statement,
'suggested_questions': app_model_config.suggested_questions_list,
'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
'speech_to_text': app_model_config.speech_to_text_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
}

View File

@@ -0,0 +1,63 @@
# -*- coding:utf-8 -*-
import logging
from flask import request
from werkzeug.exceptions import InternalServerError
import services
from controllers.web import api
from controllers.web.error import AppUnavailableError, ProviderNotInitializeError, CompletionRequestError, \
ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.web.wraps import WebApiResource
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError
from models.model import App, AppModelConfig
class AudioApi(WebApiResource):
def post(self, app_model: App, end_user):
app_model_config: AppModelConfig = app_model.app_model_config
if not app_model_config.speech_to_text_dict['enabled']:
raise AppUnavailableError()
file = request.files['file']
try:
response = AudioService.transcript(
tenant_id=app_model.tenant_id,
file=file,
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError:
raise ProviderNotInitializeError()
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
api.add_resource(AudioApi, '/audio-to-text')

View File

@@ -62,3 +62,27 @@ class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
error_code = 'app_suggested_questions_after_answer_disabled'
description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page."
code = 403
class NoAudioUploadedError(BaseHTTPException):
error_code = 'no_audio_uploaded'
description = "Please upload your audio."
code = 400
class AudioTooLargeError(BaseHTTPException):
error_code = 'audio_too_large'
description = "Audio size exceeded. {message}"
code = 413
class UnsupportedAudioTypeError(BaseHTTPException):
error_code = 'unsupported_audio_type'
description = "Audio type not allowed."
code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400

View File

@@ -3,19 +3,10 @@ from typing import Optional
import langchain
from flask import Flask
from jieba.analyse import default_tfidf
from langchain import set_handler
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
from llama_index import IndexStructType, QueryMode
from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
from pydantic import BaseModel
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.index.keyword_table.stopwords import STOPWORDS
from core.prompt.prompt_template import OneLineFormatter
from core.vector_store.vector_store import VectorStore
from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery
class HostedOpenAICredential(BaseModel):
@@ -30,23 +21,8 @@ hosted_llm_credentials = HostedLLMCredentials()
def init_app(app: Flask):
formatter = OneLineFormatter()
DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
}
INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
}
default_tfidf.stop_words = STOPWORDS
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
set_handler(DifyStdOutCallbackHandler())
if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))

View File

@@ -2,7 +2,7 @@ from typing import Optional
from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
from langchain.callbacks import CallbackManager
from langchain.callbacks.manager import CallbackManager
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
@@ -16,23 +16,20 @@ class AgentBuilder:
def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
dataset_tool_callback_handler: DatasetToolCallbackHandler,
agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=agent_loop_gather_callback_handler.model_name,
temperature=0,
max_tokens=1024,
callback_manager=llm_callback_manager
callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
)
tool_callback_manager = CallbackManager([
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
])
for tool in tools:
tool.callback_manager = tool_callback_manager
tool.callbacks = [
agent_loop_gather_callback_handler,
dataset_tool_callback_handler,
DifyStdOutCallbackHandler()
]
prompt = cls.build_agent_prompt_template(
tools=tools,
@@ -54,7 +51,7 @@ class AgentBuilder:
tools=tools,
agent=agent,
memory=memory,
callback_manager=agent_callback_manager,
callbacks=agent_callback_manager,
max_iterations=6,
early_stopping_method="generate",
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit

View File

@@ -12,6 +12,7 @@ from core.conversation_message_task import ConversationMessageTask
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
@@ -64,10 +65,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.completion = response.generations[0][0].text
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
@@ -75,21 +72,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._agent_loops = []
self._current_loop = None
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_tool_start(
self,
serialized: Dict[str, Any],
@@ -151,16 +133,6 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._agent_loops = []
self._current_loop = None
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
# Final Answer

View File

@@ -3,7 +3,6 @@ import logging
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.conversation_message_task import ConversationMessageTask
@@ -11,6 +10,7 @@ from core.conversation_message_task import ConversationMessageTask
class DatasetToolCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
@@ -66,52 +66,3 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
) -> None:
"""Do nothing."""
logging.error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
pass

View File

@@ -1,39 +1,26 @@
from llama_index import Response
from typing import List
from langchain.schema import Document
from extensions.ext_database import db
from models.dataset import DocumentSegment
class IndexToolCallbackHandler:
def __init__(self) -> None:
self._response = None
@property
def response(self) -> Response:
return self._response
def on_tool_end(self, response: Response) -> None:
"""Handle tool end."""
self._response = response
class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler):
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""
def __init__(self, dataset_id: str) -> None:
super().__init__()
self.dataset_id = dataset_id
def on_tool_end(self, response: Response) -> None:
def on_tool_end(self, documents: List[Document]) -> None:
"""Handle tool end."""
for node in response.source_nodes:
index_node_id = node.node.doc_id
for document in documents:
doc_id = document.metadata['doc_id']
# add hit count to document segment
db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.index_node_id == index_node_id
DocumentSegment.index_node_id == doc_id
).update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False

View File

@@ -3,7 +3,7 @@ import time
from typing import Any, Dict, List, Union, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage, BaseMessage
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
@@ -12,6 +12,7 @@ from core.llm.streamable_open_ai import StreamableOpenAI
class LLMCallbackHandler(BaseCallbackHandler):
raise_error: bool = True
def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
conversation_message_task: ConversationMessageTask):
@@ -25,41 +26,41 @@ class LLMCallbackHandler(BaseCallbackHandler):
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
self.start_at = time.perf_counter()
real_prompts = []
for message in messages[0]:
if message.type == 'human':
role = 'user'
elif message.type == 'ai':
role = 'assistant'
else:
role = 'system'
real_prompts.append({
"role": role,
"text": message.content
})
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.start_at = time.perf_counter()
if 'Chat' in serialized['name']:
real_prompts = []
messages = []
for prompt in prompts:
role, content = prompt.split(': ', maxsplit=1)
if role == 'human':
role = 'user'
message = HumanMessage(content=content)
elif role == 'ai':
role = 'assistant'
message = AIMessage(content=content)
else:
message = SystemMessage(content=content)
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
}]
real_prompt = {
"role": role,
"text": content
}
real_prompts.append(real_prompt)
messages.append(message)
self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
else:
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
}]
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter()
@@ -95,58 +96,3 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
else:
logging.error(error)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
pass
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
pass

View File

@@ -1,10 +1,9 @@
import logging
import time
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.entity.chain_result import ChainResult
@@ -14,6 +13,7 @@ from core.conversation_message_task import ConversationMessageTask
class MainChainGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
@@ -50,13 +50,15 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
) -> None:
"""Print out that we are entering a chain."""
if not self._current_chain_result:
self._current_chain_result = ChainResult(
type=serialized['name'],
prompt=inputs,
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
chain_type = serialized['id'][-1]
if chain_type:
self._current_chain_result = ChainResult(
type=chain_type,
prompt=inputs,
started_at=time.perf_counter()
)
self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result)
self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
@@ -74,64 +76,4 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
self.clear_chain_results()
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
pass
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.error(error)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.error(error)
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run on additional input from chains and agents."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
pass
self.clear_chain_results()

View File

@@ -1,9 +1,10 @@
import os
import sys
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage
class DifyStdOutCallbackHandler(BaseCallbackHandler):
@@ -13,17 +14,23 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Initialize callback handler."""
self.color = color
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
print_text("\n[on_chat_model_start]\n", color='blue')
for sub_messages in messages:
for sub_message in sub_messages:
print_text(str(sub_message) + "\n", color='blue')
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print_text("\n[on_llm_start]\n", color='blue')
if 'Chat' in serialized['name']:
for prompt in prompts:
print_text(prompt + "\n", color='blue')
else:
print_text(prompts[0] + "\n", color='blue')
print_text(prompts[0] + "\n", color='blue')
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
@@ -44,8 +51,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized["name"]
print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink')
chain_type = serialized['id'][-1]
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
@@ -117,6 +124,26 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Run on agent end."""
print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""

View File

@@ -1,7 +1,5 @@
from typing import Optional
from langchain.callbacks import CallbackManager
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.tool_chain import ToolChain
@@ -14,7 +12,7 @@ class ChainBuilder:
tool=tool,
input_key=kwargs.get('input_key', 'input'),
output_key=kwargs.get('output_key', 'tool_output'),
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
callbacks=[DifyStdOutCallbackHandler()]
)
@classmethod
@@ -27,7 +25,7 @@ class ChainBuilder:
sensitive_words=sensitive_words.split(","),
canned_response=tool_config.get("canned_response", ''),
output_key="sensitive_word_avoidance_output",
callback_manager=CallbackManager([DifyStdOutCallbackHandler()]),
callbacks=[DifyStdOutCallbackHandler()],
**kwargs
)

View File

@@ -1,15 +1,16 @@
"""Base classes for LLM-powered router chains."""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import root_validator
from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
from langchain.schema import BaseOutputParser, OutputParserException
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -51,8 +52,9 @@ class LLMRouterChain(Chain):
raise ValueError
def _call(
self,
inputs: Dict[str, Any]
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
output = cast(
Dict[str, Any],

View File

@@ -1,11 +1,9 @@
from typing import Optional, List
from typing import Optional, List, cast
from langchain.callbacks import SharedCallbackManager, CallbackManager
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
@@ -18,6 +16,7 @@ from models.dataset import Dataset
class MainChainBuilder:
@classmethod
def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
rest_tokens: int,
conversation_message_task: ConversationMessageTask):
first_input_key = "input"
final_output_key = "output"
@@ -30,6 +29,7 @@ class MainChainBuilder:
tool_chains, chains_output_key = cls.get_agent_chains(
tenant_id=tenant_id,
agent_mode=agent_mode,
rest_tokens=rest_tokens,
memory=memory,
conversation_message_task=conversation_message_task
)
@@ -42,9 +42,8 @@ class MainChainBuilder:
return None
for chain in chains:
# do not add handler into singleton callback manager
if not isinstance(chain.callback_manager, SharedCallbackManager):
chain.callback_manager.add_handler(chain_callback_handler)
chain = cast(Chain, chain)
chain.callbacks.append(chain_callback_handler)
# build main chain
overall_chain = SequentialChain(
@@ -57,7 +56,9 @@ class MainChainBuilder:
return overall_chain
@classmethod
def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
def get_agent_chains(cls, tenant_id: str, agent_mode: dict,
rest_tokens: int,
memory: Optional[BaseChatMemory],
conversation_message_task: ConversationMessageTask):
# agent mode
chains = []
@@ -93,7 +94,8 @@ class MainChainBuilder:
tenant_id=tenant_id,
datasets=datasets,
conversation_message_task=conversation_message_task,
callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
rest_tokens=rest_tokens,
callbacks=[DifyStdOutCallbackHandler()]
)
chains.append(multi_dataset_router_chain)

View File

@@ -1,9 +1,10 @@
import math
import re
from typing import Mapping, List, Dict, Any, Optional
from langchain import LLMChain, PromptTemplate, ConversationChain
from langchain.callbacks import CallbackManager
from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseLanguageModel
from pydantic import Extra
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
@@ -11,10 +12,11 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_tool_builder import DatasetToolBuilder
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset
from core.tool.dataset_index_tool import DatasetTool
from models.dataset import Dataset, DatasetProcessRule
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
@@ -52,7 +54,7 @@ class MultiDatasetRouterChain(Chain):
router_chain: LLMRouterChain
"""Chain for deciding a destination chain and the input to it."""
dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
dataset_tools: Mapping[str, DatasetTool]
"""Map of name to candidate chains that inputs can be routed to."""
class Config:
@@ -79,41 +81,56 @@ class MultiDatasetRouterChain(Chain):
tenant_id: str,
datasets: List[Dataset],
conversation_message_task: ConversationMessageTask,
rest_tokens: int,
**kwargs: Any,
):
"""Convenience constructor for instantiating from destination prompts."""
llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=1024,
callback_manager=llm_callback_manager
callbacks=[DifyStdOutCallbackHandler()]
)
destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ') if d.description
destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
else ('useful for when you want to answer queries about the ' + d.name))
for d in datasets]
destinations_str = "\n".join(destinations)
router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
destinations=destinations_str
)
router_prompt = PromptTemplate(
template=router_template,
input_variables=["input"],
output_parser=RouterOutputParser(),
)
router_chain = LLMRouterChain.from_llm(llm, router_prompt)
dataset_tools = {}
for dataset in datasets:
dataset_tool = DatasetToolBuilder.build_dataset_tool(
# fulfill description when it is empty
if dataset.available_document_count == 0 or dataset.available_document_count == 0:
continue
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
if k == 0:
continue
dataset_tool = DatasetTool(
name=f"dataset-{dataset.id}",
description=description,
k=k,
dataset=dataset,
response_mode='no_synthesizer', # "compact"
callback_handler=DatasetToolCallbackHandler(conversation_message_task)
callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
)
if dataset_tool:
dataset_tools[dataset.id] = dataset_tool
dataset_tools[str(dataset.id)] = dataset_tool
return cls(
router_chain=router_chain,
@@ -121,9 +138,39 @@ class MultiDatasetRouterChain(Chain):
**kwargs,
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
else:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space
return context_limit_tokens // segment_max_tokens
def _call(
self,
inputs: Dict[str, Any]
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if len(self.dataset_tools) == 0:
return {"text": ''}
@@ -132,13 +179,20 @@ class MultiDatasetRouterChain(Chain):
route = self.router_chain.route(inputs)
if not route.destination:
destination = ''
if route.destination:
pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
match = re.search(pattern, route.destination, re.IGNORECASE)
if match:
destination = match.group()
if not destination:
return {"text": ''}
elif route.destination in self.dataset_tools:
return {"text": self.dataset_tools[route.destination].run(
elif destination in self.dataset_tools:
return {"text": self.dataset_tools[destination].run(
route.next_inputs['input']
)}
else:
raise ValueError(
f"Received invalid destination chain name '{route.destination}'"
f"Received invalid destination chain name '{destination}'"
)

View File

@@ -1,5 +1,6 @@
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
@@ -36,7 +37,11 @@ class SensitiveWordAvoidanceChain(Chain):
return self.canned_response
return text
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
output = self._check_sensitive_word(text)
return {self.output_key: output}

View File

@@ -1,5 +1,6 @@
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun, AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.tools import BaseTool
@@ -30,12 +31,20 @@ class ToolChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
input = inputs[self.input_key]
output = self.tool.run(input, self.verbose)
return {self.output_key: output}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
input = inputs[self.input_key]
output = await self.tool.arun(input, self.verbose)

View File

@@ -1,17 +1,18 @@
import logging
from typing import Optional, List, Union, Tuple
from langchain.callbacks import CallbackManager
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, PubHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
@@ -22,7 +23,7 @@ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message
@@ -34,7 +35,7 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
cls.validate_query_tokens(app.tenant_id, app_model_config, query)
query = PromptBuilder.process_template(query)
memory = None
if conversation:
@@ -48,6 +49,14 @@ class Completion:
inputs = conversation.inputs
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
tenant_id=app.tenant_id,
app_model_config=app_model_config,
query=query,
inputs=inputs
)
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
@@ -64,6 +73,7 @@ class Completion:
main_chain = MainChainBuilder.to_langchain_components(
tenant_id=app.tenant_id,
agent_mode=app_model_config.agent_mode_dict,
rest_tokens=rest_tokens_for_context_and_memory,
memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
conversation_message_task=conversation_message_task
)
@@ -115,7 +125,7 @@ class Completion:
memory=memory
)
final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=final_llm,
@@ -133,18 +143,17 @@ class Completion:
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
# disable template string in query
query_params = OutLinePromptTemplate.from_template(template=query).input_variables
if query_params:
for query_param in query_params:
if query_param not in inputs:
inputs[query_param] = '{' + query_param + '}'
# query_params = JinjaPromptTemplate.from_template(template=query).input_variables
# if query_params:
# for query_param in query_params:
# if query_param not in inputs:
# inputs[query_param] = '{{' + query_param + '}}'
pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
if mode == 'completion':
prompt_template = OutLinePromptTemplate.from_template(
prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following CONTEXT as your learned knowledge:
[CONTEXT]
{context}
{{context}}
[END CONTEXT]
When answer to user:
@@ -154,16 +163,16 @@ Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if chain_output else "")
+ (pre_prompt + "\n" if pre_prompt else "")
+ "{query}\n"
+ "{{query}}\n"
)
if chain_output:
inputs['context'] = chain_output
context_params = OutLinePromptTemplate.from_template(template=chain_output).input_variables
if context_params:
for context_param in context_params:
if context_param not in inputs:
inputs[context_param] = '{' + context_param + '}'
# context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
# if context_params:
# for context_param in context_params:
# if context_param not in inputs:
# inputs[context_param] = '{{' + context_param + '}}'
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_content = prompt_template.format(
@@ -187,7 +196,7 @@ And answer according to the language of the user's question.
if pre_prompt:
pre_prompt_inputs = {k: inputs[k] for k in
OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
if k in inputs}
if pre_prompt_inputs:
@@ -197,7 +206,7 @@ And answer according to the language of the user's question.
human_inputs['context'] = chain_output
human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
{context}
{{context}}
[END CONTEXT]
When answer to user:
@@ -210,7 +219,7 @@ And answer according to the language of the user's question.
if pre_prompt:
human_message_prompt += pre_prompt
query_prompt = "\nHuman: {query}\nAI: "
query_prompt = "\nHuman: {{query}}\nAI: "
if memory:
# append chat histories
@@ -226,11 +235,11 @@ And answer according to the language of the user's question.
histories = cls.get_history_messages_from_memory(memory, rest_tokens)
# disable template string in query
histories_params = OutLinePromptTemplate.from_template(template=histories).input_variables
if histories_params:
for histories_param in histories_params:
if histories_param not in human_inputs:
human_inputs[histories_param] = '{' + histories_param + '}'
# histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
# if histories_params:
# for histories_param in histories_params:
# if histories_param not in human_inputs:
# human_inputs[histories_param] = '{{' + histories_param + '}}'
human_message_prompt += "\n\n" + histories
@@ -247,16 +256,14 @@ And answer according to the language of the user's question.
return messages, ['\nHuman:']
@classmethod
def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
streaming: bool,
conversation_message_task: ConversationMessageTask) -> CallbackManager:
def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
streaming: bool,
conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
if streaming:
callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
else:
callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]
return CallbackManager(callback_handlers)
return [llm_callback_handler, DifyStdOutCallbackHandler()]
@classmethod
def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
@@ -293,7 +300,8 @@ And answer according to the language of the user's question.
return memory
@classmethod
def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
query: str, inputs: dict) -> int:
llm = LLMBuilder.to_llm_from_model(
tenant_id=tenant_id,
model=app_model_config.model_dict
@@ -302,8 +310,26 @@ And answer according to the language of the user's question.
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
max_tokens = llm.max_tokens
if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
raise LLMBadRequestError("Query is too long")
# get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt(
mode=mode,
llm=llm,
pre_prompt=app_model_config.pre_prompt,
query=query,
inputs=inputs,
chain_output=None,
memory=None
)
prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
else llm.get_num_tokens_from_messages(prompt)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
return rest_tokens
@classmethod
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
@@ -360,7 +386,7 @@ And answer according to the language of the user's question.
streaming=streaming
)
llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)
llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
cls.recale_llm_max_tokens(
final_llm=llm,

View File

@@ -10,7 +10,7 @@ from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.provider.llm_provider_service import LLMProviderService
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompt_template import JinjaPromptTemplate
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -78,7 +78,7 @@ class ConversationMessageTask:
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
prompt_template = OutLinePromptTemplate.from_template(template=PromptBuilder.process_template(introduction))
prompt_template = JinjaPromptTemplate.from_template(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
try:
introduction = prompt_template.format(**prompt_inputs)
@@ -86,8 +86,7 @@ class ConversationMessageTask:
pass
if self.app_model_config.pre_prompt:
pre_prompt = PromptBuilder.process_template(self.app_model_config.pre_prompt)
system_message = PromptBuilder.to_system_message(pre_prompt, self.inputs)
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
system_instruction_tokens = llm.get_messages_tokens([system_message])
@@ -157,7 +156,7 @@ class ConversationMessageTask:
self.message.message = llm_message.prompt
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.answer = llm_message.completion.strip() if llm_message.completion else ''
self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
self.message.provider_response_latency = llm_message.latency
@@ -293,12 +292,12 @@ class PubHandler:
if not user:
raise ValueError("user is required")
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result:{}-{}".format(user_str, task_id)
@classmethod
def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
return "generate_result_stopped:{}-{}".format(user_str, task_id)
def pub_text(self, text: str):
@@ -306,10 +305,10 @@ class PubHandler:
'event': 'message',
'data': {
'task_id': self._task_id,
'message_id': self._message.id,
'message_id': str(self._message.id),
'text': text,
'mode': self._conversation.mode,
'conversation_id': self._conversation.id
'conversation_id': str(self._conversation.id)
}
}

View File

@@ -0,0 +1,43 @@
import tempfile
from pathlib import Path
from typing import List, Union
from langchain.document_loaders import TextLoader, Docx2txtLoader
from langchain.schema import Document
from core.data_loader.loader.csv import CSVLoader
from core.data_loader.loader.excel import ExcelLoader
from core.data_loader.loader.html import HTMLLoader
from core.data_loader.loader.markdown import MarkdownLoader
from core.data_loader.loader.pdf import PdfLoader
from extensions.ext_storage import storage
from models.model import UploadFile
class FileExtractor:
@classmethod
def load(cls, upload_file: UploadFile, return_text: bool = False) -> Union[List[Document] | str]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
delimiter = '\n'
if input_file.suffix == '.xlsx':
loader = ExcelLoader(file_path)
elif input_file.suffix == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif input_file.suffix in ['.md', '.markdown']:
loader = MarkdownLoader(file_path, autodetect_encoding=True)
elif input_file.suffix in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif input_file.suffix == '.docx':
loader = Docx2txtLoader(file_path)
elif input_file.suffix == '.csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
else:
# txt
loader = TextLoader(file_path, autodetect_encoding=True)
return delimiter.join([document.page_content for document in loader.load()]) if return_text else loader.load()

View File

@@ -0,0 +1,67 @@
import logging
from typing import Optional, Dict, List
from langchain.document_loaders import CSVLoader as LCCSVLoader
from langchain.document_loaders.helpers import detect_file_encodings
from models.dataset import Document
logger = logging.getLogger(__name__)
class CSVLoader(LCCSVLoader):
def __init__(
self,
file_path: str,
source_column: Optional[str] = None,
csv_args: Optional[Dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
self.file_path = file_path
self.source_column = source_column
self.encoding = encoding
self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
"""Load data into document objects."""
try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
docs = self._read_from_file(csvfile)
except UnicodeDecodeError as e:
if self.autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path)
for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try:
with open(self.file_path, newline="", encoding=encoding.encoding) as csvfile:
docs = self._read_from_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self.file_path}") from e
return docs
def _read_from_file(self, csvfile):
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
if self.source_column is not None
else ''
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs

View File

@@ -0,0 +1,45 @@
import json
import logging
from typing import List
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from openpyxl.reader.excel import load_workbook
logger = logging.getLogger(__name__)
class ExcelLoader(BaseLoader):
"""Load xlxs files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
data = []
keys = []
wb = load_workbook(filename=self._file_path, read_only=True)
# loop over all sheets
for sheet in wb:
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
if keys == []:
keys = list(map(str, row))
else:
row_dict = dict(zip(keys, list(map(str, row))))
row_dict = {k: v for k, v in row_dict.items() if v}
item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
document = Document(page_content=item)
data.append(document)
return data

View File

@@ -0,0 +1,35 @@
import logging
from typing import List
from bs4 import BeautifulSoup
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
logger = logging.getLogger(__name__)
class HTMLLoader(BaseLoader):
"""Load html files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str
):
"""Initialize with file path."""
self._file_path = file_path
def load(self) -> List[Document]:
return [Document(page_content=self._load_as_text())]
def _load_as_text(self) -> str:
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text

View File

@@ -0,0 +1,134 @@
import logging
import re
from typing import Optional, List, Tuple, cast
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.helpers import detect_file_encodings
from langchain.schema import Document
logger = logging.getLogger(__name__)
class MarkdownLoader(BaseLoader):
"""Load md files.
Args:
file_path: Path to the file to load.
remove_hyperlinks: Whether to remove hyperlinks from the text.
remove_images: Whether to remove images from the text.
encoding: File encoding to use. If `None`, the file will be loaded
with the default system encoding.
autodetect_encoding: Whether to try to autodetect the file encoding
if the specified encoding fails.
"""
def __init__(
self,
file_path: str,
remove_hyperlinks: bool = True,
remove_images: bool = True,
encoding: Optional[str] = None,
autodetect_encoding: bool = True,
):
"""Initialize with file path."""
self._file_path = file_path
self._remove_hyperlinks = remove_hyperlinks
self._remove_images = remove_images
self._encoding = encoding
self._autodetect_encoding = autodetect_encoding
def load(self) -> List[Document]:
tups = self.parse_tups(self._file_path)
documents = []
for header, value in tups:
value = value.strip()
if header is None:
documents.append(Document(page_content=value))
else:
documents.append(Document(page_content=f"\n\n{header}\n{value}"))
return documents
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
"""Convert a markdown file to a dictionary.
The keys are the headers and the values are the text under each header.
"""
markdown_tups: List[Tuple[Optional[str], str]] = []
lines = markdown_text.split("\n")
current_header = None
current_text = ""
for line in lines:
header_match = re.match(r"^#+\s", line)
if header_match:
if current_header is not None:
markdown_tups.append((current_header, current_text))
current_header = line
current_text = ""
else:
current_text += line + "\n"
markdown_tups.append((current_header, current_text))
if current_header is not None:
# pass linting, assert keys are defined
markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]
else:
markdown_tups = [
(key, re.sub("\n", "", value)) for key, value in markdown_tups
]
return markdown_tups
def remove_images(self, content: str) -> str:
"""Get a dictionary of a markdown file from its path."""
pattern = r"!{1}\[\[(.*)\]\]"
content = re.sub(pattern, "", content)
return content
def remove_hyperlinks(self, content: str) -> str:
"""Get a dictionary of a markdown file from its path."""
pattern = r"\[(.*?)\]\((.*?)\)"
content = re.sub(pattern, r"\1", content)
return content
def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
"""Parse file into tuples."""
content = ""
try:
with open(filepath, "r", encoding=self._encoding) as f:
content = f.read()
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings:
logger.debug("Trying encoding: ", encoding.encoding)
try:
with open(filepath, encoding=encoding.encoding) as f:
content = f.read()
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {filepath}") from e
except Exception as e:
raise RuntimeError(f"Error loading {filepath}") from e
if self._remove_hyperlinks:
content = self.remove_hyperlinks(content)
if self._remove_images:
content = self.remove_images(content)
return self.markdown_to_tups(content)

View File

@@ -1,67 +1,234 @@
"""Notion reader."""
import json
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import List, Dict, Any, Optional
import requests # type: ignore
import requests
from flask import current_app
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document
from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from models.source import DataSourceBinding
logger = logging.getLogger(__name__)
INTEGRATION_TOKEN_NAME = "NOTION_INTEGRATION_TOKEN"
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
logger = logging.getLogger(__name__)
# TODO: Notion DB reader coming soon!
class NotionPageReader(BaseReader):
"""Notion Page reader.
class NotionLoader(BaseLoader):
def __init__(
self,
notion_access_token: str,
notion_workspace_id: str,
notion_obj_id: str,
notion_page_type: str,
document_model: Optional[DocumentModel] = None
):
self._document_model = document_model
self._notion_workspace_id = notion_workspace_id
self._notion_obj_id = notion_obj_id
self._notion_page_type = notion_page_type
self._notion_access_token = notion_access_token
Reads a set of Notion pages.
Args:
integration_token (str): Notion integration token.
"""
def __init__(self, integration_token: Optional[str] = None) -> None:
"""Initialize with parameters."""
if integration_token is None:
integration_token = os.getenv(INTEGRATION_TOKEN_NAME)
if not self._notion_access_token:
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
if integration_token is None:
raise ValueError(
"Must specify `integration_token` or set environment "
"variable `NOTION_INTEGRATION_TOKEN`."
)
self.token = integration_token
self.headers = {
"Authorization": "Bearer " + self.token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
}
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
"""Read a block."""
done = False
self._notion_access_token = integration_token
@classmethod
def from_document(cls, document_model: DocumentModel):
data_source_info = document_model.data_source_info_dict
if not data_source_info or 'notion_page_id' not in data_source_info \
or 'notion_workspace_id' not in data_source_info:
raise ValueError("no notion page found")
notion_workspace_id = data_source_info['notion_workspace_id']
notion_obj_id = data_source_info['notion_page_id']
notion_page_type = data_source_info['type']
notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
return cls(
notion_access_token=notion_access_token,
notion_workspace_id=notion_workspace_id,
notion_obj_id=notion_obj_id,
notion_page_type=notion_page_type,
document_model=document_model
)
def load(self) -> List[Document]:
self.update_last_edited_time(
self._document_model
)
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
return text_docs
def _load_data_as_documents(
self, notion_obj_id: str, notion_page_type: str
) -> List[Document]:
docs = []
if notion_page_type == 'database':
# get all the pages in the database
page_text_documents = self._get_notion_database_data(notion_obj_id)
docs.extend(page_text_documents)
elif notion_page_type == 'page':
page_text_list = self._get_notion_block_data(notion_obj_id)
for page_text in page_text_list:
docs.append(Document(page_content=page_text))
else:
raise ValueError("notion page type not supported")
return docs
def _get_notion_database_data(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> List[Document]:
"""Get all the pages from a Notion database."""
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict,
)
data = res.json()
database_content_list = []
if 'results' not in data or data["results"] is None:
return []
for result in data["results"]:
properties = result['properties']
data = {}
for property_name, property_value in properties.items():
type = property_value['type']
if type == 'multi_select':
value = []
multi_select_list = property_value[type]
for multi_select in multi_select_list:
value.append(multi_select['name'])
elif type == 'rich_text' or type == 'title':
if len(property_value[type]) > 0:
value = property_value[type][0]['plain_text']
else:
value = ''
elif type == 'select' or type == 'status':
if property_value[type]:
value = property_value[type]['name']
else:
value = ''
else:
value = property_value[type]
data[property_name] = value
row_dict = {k: v for k, v in data.items() if v}
row_content = ''
for key, value in row_dict.items():
if isinstance(value, dict):
value_dict = {k: v for k, v in value.items() if v}
value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items())
row_content = row_content + f'{key}:{value_content}\n'
else:
row_content = row_content + f'{key}:{value}\n'
document = Document(page_content=row_content)
database_content_list.append(document)
return database_content_list
def _get_notion_block_data(self, page_id: str) -> List[str]:
result_lines_arr = []
cur_block_id = block_id
while not done:
cur_block_id = page_id
while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", block_url, headers=self.headers, json=query_dict
"GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
# current block's heading
heading = ''
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
cur_result_text_arr = []
if result_type == 'table':
result_block_id = result["id"]
text = self._read_table_rows(result_block_id)
text += "\n\n"
result_lines_arr.append(text)
else:
if "rich_text" in result_obj:
for rich_text in result_obj["rich_text"]:
# skip if doesn't have text object
if "text" in rich_text:
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
if result_type in HEADING_TYPE:
heading = text
result_block_id = result["id"]
has_children = result["has_children"]
block_type = result["type"]
if has_children and block_type != 'child_page':
children_text = self._read_block(
result_block_id, num_tabs=1
)
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
cur_result_text += "\n\n"
if result_type in HEADING_TYPE:
result_lines_arr.append(cur_result_text)
else:
result_lines_arr.append(f'{heading}\n{cur_result_text}')
if data["next_cursor"] is None:
break
else:
cur_block_id = data["next_cursor"]
return result_lines_arr
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
"""Read a block."""
result_lines_arr = []
cur_block_id = block_id
while True:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
if 'results' not in data or data["results"] is None:
done = True
break
heading = ''
for result in data["results"]:
@@ -98,7 +265,6 @@ class NotionPageReader(BaseReader):
result_lines_arr.append(f'{heading}\n{cur_result_text}')
if data["next_cursor"] is None:
done = True
break
else:
cur_block_id = data["next_cursor"]
@@ -116,7 +282,14 @@ class NotionPageReader(BaseReader):
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", block_url, headers=self.headers, json=query_dict
"GET",
block_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
# get table headers text
@@ -129,9 +302,9 @@ class NotionPageReader(BaseReader):
table_header_cell_texts.append(text)
# get table columns text and format
results = data["results"]
for i in range(len(results)-1):
for i in range(len(results) - 1):
column_texts = []
tabel_column_cells = data["results"][i+1]['table_row']['cells']
tabel_column_cells = data["results"][i + 1]['table_row']['cells']
for j in range(len(tabel_column_cells)):
if tabel_column_cells[j]:
for table_column_cell_text in tabel_column_cells[j]:
@@ -149,221 +322,58 @@ class NotionPageReader(BaseReader):
result_lines = "\n".join(result_lines_arr)
return result_lines
def _read_parent_blocks(self, block_id: str, num_tabs: int = 0) -> List[str]:
"""Read a block."""
done = False
result_lines_arr = []
cur_block_id = block_id
while not done:
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", block_url, headers=self.headers, json=query_dict
def update_last_edited_time(self, document_model: DocumentModel):
if not document_model:
return
last_edited_time = self.get_notion_last_edited_time()
data_source_info = document_model.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
DocumentModel.data_source_info: json.dumps(data_source_info)
}
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
db.session.commit()
def get_notion_last_edited_time(self) -> str:
obj_id = self._notion_obj_id
page_type = self._notion_page_type
if page_type == 'database':
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
else:
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET",
retrieve_page_url,
headers={
"Authorization": "Bearer " + self._notion_access_token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
},
json=query_dict
)
data = res.json()
return data["last_edited_time"]
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
)
data = res.json()
# current block's heading
heading = ''
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
cur_result_text_arr = []
if result_type == 'table':
result_block_id = result["id"]
text = self._read_table_rows(result_block_id)
text += "\n\n"
result_lines_arr.append(text)
else:
if "rich_text" in result_obj:
for rich_text in result_obj["rich_text"]:
# skip if doesn't have text object
if "text" in rich_text:
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
if result_type in HEADING_TYPE:
heading = text
).first()
result_block_id = result["id"]
has_children = result["has_children"]
block_type = result["type"]
if has_children and block_type != 'child_page':
children_text = self._read_block(
result_block_id, num_tabs=num_tabs + 1
)
cur_result_text_arr.append(children_text)
if not data_source_binding:
raise Exception(f'No notion data source binding found for tenant {tenant_id} '
f'and notion workspace {notion_workspace_id}')
cur_result_text = "\n".join(cur_result_text_arr)
cur_result_text += "\n\n"
if result_type in HEADING_TYPE:
result_lines_arr.append(cur_result_text)
else:
result_lines_arr.append(f'{heading}\n{cur_result_text}')
if data["next_cursor"] is None:
done = True
break
else:
cur_block_id = data["next_cursor"]
return result_lines_arr
def read_page(self, page_id: str) -> str:
"""Read a page."""
return self._read_block(page_id)
def read_page_as_documents(self, page_id: str) -> List[str]:
"""Read a page as documents."""
return self._read_parent_blocks(page_id)
def query_database_data(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> str:
"""Get all the pages from a Notion database."""
res = requests.post\
(
DATABASE_URL_TMPL.format(database_id=database_id),
headers=self.headers,
json=query_dict,
)
data = res.json()
database_content_list = []
if 'results' not in data or data["results"] is None:
return ""
for result in data["results"]:
properties = result['properties']
data = {}
for property_name, property_value in properties.items():
type = property_value['type']
if type == 'multi_select':
value = []
multi_select_list = property_value[type]
for multi_select in multi_select_list:
value.append(multi_select['name'])
elif type == 'rich_text' or type == 'title':
if len(property_value[type]) > 0:
value = property_value[type][0]['plain_text']
else:
value = ''
elif type == 'select' or type == 'status':
if property_value[type]:
value = property_value[type]['name']
else:
value = ''
else:
value = property_value[type]
data[property_name] = value
database_content_list.append(json.dumps(data))
return "\n\n".join(database_content_list)
def query_database(
self, database_id: str, query_dict: Dict[str, Any] = {}
) -> List[str]:
"""Get all the pages from a Notion database."""
res = requests.post\
(
DATABASE_URL_TMPL.format(database_id=database_id),
headers=self.headers,
json=query_dict,
)
data = res.json()
page_ids = []
for result in data["results"]:
page_id = result["id"]
page_ids.append(page_id)
return page_ids
def search(self, query: str) -> List[str]:
"""Search Notion page given a text query."""
done = False
next_cursor: Optional[str] = None
page_ids = []
while not done:
query_dict = {
"query": query,
}
if next_cursor is not None:
query_dict["start_cursor"] = next_cursor
res = requests.post(SEARCH_URL, headers=self.headers, json=query_dict)
data = res.json()
for result in data["results"]:
page_id = result["id"]
page_ids.append(page_id)
if data["next_cursor"] is None:
done = True
break
else:
next_cursor = data["next_cursor"]
return page_ids
def load_data(
self, page_ids: List[str] = [], database_id: Optional[str] = None
) -> List[Document]:
"""Load data from the input directory.
Args:
page_ids (List[str]): List of page ids to load.
Returns:
List[Document]: List of documents.
"""
if not page_ids and not database_id:
raise ValueError("Must specify either `page_ids` or `database_id`.")
docs = []
if database_id is not None:
# get all the pages in the database
page_ids = self.query_database(database_id)
for page_id in page_ids:
page_text = self.read_page(page_id)
docs.append(Document(page_text))
else:
for page_id in page_ids:
page_text = self.read_page(page_id)
docs.append(Document(page_text))
return docs
def load_data_as_documents(
self, page_ids: List[str] = [], database_id: Optional[str] = None
) -> List[Document]:
if not page_ids and not database_id:
raise ValueError("Must specify either `page_ids` or `database_id`.")
docs = []
if database_id is not None:
# get all the pages in the database
page_text = self.query_database_data(database_id)
docs.append(Document(page_text))
else:
for page_id in page_ids:
page_text_list = self.read_page_as_documents(page_id)
for page_text in page_text_list:
docs.append(Document(page_text))
return docs
def get_page_last_edited_time(self, page_id: str) -> str:
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=page_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", retrieve_page_url, headers=self.headers, json=query_dict
)
data = res.json()
return data["last_edited_time"]
def get_database_last_edited_time(self, database_id: str) -> str:
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=database_id)
query_dict: Dict[str, Any] = {}
res = requests.request(
"GET", retrieve_page_url, headers=self.headers, json=query_dict
)
data = res.json()
return data["last_edited_time"]
if __name__ == "__main__":
reader = NotionPageReader()
logger.info(reader.search("What I"))
return data_source_binding.access_token

View File

@@ -0,0 +1,55 @@
import logging
from typing import List, Optional
from langchain.document_loaders import PyPDFium2Loader
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from extensions.ext_storage import storage
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfLoader(BaseLoader):
"""Load pdf files.
Args:
file_path: Path to the file to load.
"""
def __init__(
self,
file_path: str,
upload_file: Optional[UploadFile] = None
):
"""Initialize with file path."""
self._file_path = file_path
self._upload_file = upload_file
def load(self) -> List[Document]:
plaintext_file_key = ''
plaintext_file_exists = False
if self._upload_file:
if self._upload_file.hash:
plaintext_file_key = 'upload_files/' + self._upload_file.tenant_id + '/' \
+ self._upload_file.hash + '.0625.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = PyPDFium2Loader(file_path=self._file_path).load()
text_list = []
for document in documents:
text_list.append(document.page_content)
text = "\n\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return documents

View File

@@ -1,10 +1,6 @@
from typing import Any, Dict, Optional, Sequence
import tiktoken
from llama_index.data_structs import Node
from llama_index.docstore.types import BaseDocumentStore
from llama_index.docstore.utils import json_to_doc
from llama_index.schema import BaseDocument
from langchain.schema import Document
from sqlalchemy import func
from core.llm.token_calculator import TokenCalculator
@@ -12,7 +8,7 @@ from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
class DatesetDocumentStore(BaseDocumentStore):
class DatesetDocumentStore:
def __init__(
self,
dataset: Dataset,
@@ -48,7 +44,7 @@ class DatesetDocumentStore(BaseDocumentStore):
return self._embedding_model_name
@property
def docs(self) -> Dict[str, BaseDocument]:
def docs(self) -> Dict[str, Document]:
document_segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id
).all()
@@ -56,13 +52,20 @@ class DatesetDocumentStore(BaseDocumentStore):
output = {}
for document_segment in document_segments:
doc_id = document_segment.index_node_id
result = self.segment_to_dict(document_segment)
output[doc_id] = json_to_doc(result)
output[doc_id] = Document(
page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
return output
def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True
self, docs: Sequence[Document], allow_update: bool = True
) -> None:
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
DocumentSegment.document == self._document_id
@@ -72,23 +75,20 @@ class DatesetDocumentStore(BaseDocumentStore):
max_position = 0
for doc in docs:
if doc.is_doc_id_none:
raise ValueError("doc_id not set")
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")
if not isinstance(doc, Node):
raise ValueError("doc must be a Node")
segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False)
segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
# NOTE: doc could already exist in the store, but we overwrite it
if not allow_update and segment_document:
raise ValueError(
f"doc_id {doc.get_doc_id()} already exists. "
f"doc_id {doc.metadata['doc_id']} already exists. "
"Set allow_update to True to overwrite."
)
# calc embedding use tokens
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text())
tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content)
if not segment_document:
max_position += 1
@@ -97,19 +97,19 @@ class DatesetDocumentStore(BaseDocumentStore):
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
index_node_id=doc.get_doc_id(),
index_node_hash=doc.get_doc_hash(),
index_node_id=doc.metadata['doc_id'],
index_node_hash=doc.metadata['doc_hash'],
position=max_position,
content=doc.get_text(),
word_count=len(doc.get_text()),
content=doc.page_content,
word_count=len(doc.page_content),
tokens=tokens,
created_by=self._user_id,
)
db.session.add(segment_document)
else:
segment_document.content = doc.get_text()
segment_document.index_node_hash = doc.get_doc_hash()
segment_document.word_count = len(doc.get_text())
segment_document.content = doc.page_content
segment_document.index_node_hash = doc.metadata['doc_hash']
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
db.session.commit()
@@ -121,7 +121,7 @@ class DatesetDocumentStore(BaseDocumentStore):
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]:
) -> Optional[Document]:
document_segment = self.get_document_segment(doc_id)
if document_segment is None:
@@ -130,8 +130,15 @@ class DatesetDocumentStore(BaseDocumentStore):
else:
return None
result = self.segment_to_dict(document_segment)
return json_to_doc(result)
return Document(
page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
document_segment = self.get_document_segment(doc_id)
@@ -164,15 +171,6 @@ class DatesetDocumentStore(BaseDocumentStore):
return document_segment.index_node_hash
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))
def get_document_segment(self, doc_id: str) -> DocumentSegment:
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self._dataset.id,
@@ -180,11 +178,3 @@ class DatesetDocumentStore(BaseDocumentStore):
).first()
return document_segment
def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]:
return {
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"text": segment.content,
"__type__": Node.get_type()
}

View File

@@ -1,51 +0,0 @@
from typing import Any, Dict, Optional, Sequence
from llama_index.docstore.types import BaseDocumentStore
from llama_index.schema import BaseDocument
class EmptyDocumentStore(BaseDocumentStore):
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore":
return cls()
def to_dict(self) -> Dict[str, Any]:
"""Serialize to dict."""
return {}
@property
def docs(self) -> Dict[str, BaseDocument]:
return {}
def add_documents(
self, docs: Sequence[BaseDocument], allow_update: bool = True
) -> None:
pass
def document_exists(self, doc_id: str) -> bool:
"""Check if document exists."""
return False
def get_document(
self, doc_id: str, raise_error: bool = True
) -> Optional[BaseDocument]:
return None
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
pass
def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
"""Set the hash for a given doc_id."""
pass
def get_document_hash(self, doc_id: str) -> Optional[str]:
"""Get the stored hash for a document, if it exists."""
return None
def update_docstore(self, other: "BaseDocumentStore") -> None:
"""Update docstore.
Args:
other (BaseDocumentStore): docstore to update from
"""
self.add_documents(list(other.docs.values()))

View File

@@ -0,0 +1,72 @@
import logging
from typing import List
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError
from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding
class CacheEmbedding(Embeddings):
def __init__(self, embeddings: Embeddings):
self._embeddings = embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
# use doc embedding cache or store if not exists
text_embeddings = []
embedding_queue_texts = []
for text in texts:
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
text_embeddings.append(embedding.get_embedding())
else:
embedding_queue_texts.append(text)
embedding_results = self._embeddings.embed_documents(embedding_queue_texts)
i = 0
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results[i])
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
i += 1
text_embeddings.extend(embedding_results)
return text_embeddings
def embed_query(self, text: str) -> List[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(hash=hash).first()
if embedding:
return embedding.get_embedding()
embedding_results = self._embeddings.embed_query(text)
try:
embedding = Embedding(hash=hash)
embedding.set_embedding(embedding_results)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
except:
logging.exception('Failed to add embedding to db')
return embedding_results

View File

@@ -1,214 +0,0 @@
from typing import Optional, Any, List
import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
_TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
text: str,
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[float]:
"""Get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
float]:
"""Asynchronously get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
"embedding"
]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
list_of_text: List[str],
engine: Optional[str] = None,
api_key: Optional[str] = None,
**kwargs
) -> List[List[float]]:
"""Get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]:
"""Asynchronously get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
# replace newlines, which can negatively affect performance.
list_of_text = [text.replace("\n", " ") for text in list_of_text]
data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input.
return [d["embedding"] for d in data]
class OpenAIEmbedding(BaseEmbedding):
def __init__(
self,
mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
deployment_name: Optional[str] = None,
openai_api_key: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Init params."""
new_kwargs = {}
if 'embed_batch_size' in kwargs:
new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
if 'tokenizer' in kwargs:
new_kwargs['tokenizer'] = kwargs['tokenizer']
super().__init__(**new_kwargs)
self.mode = OpenAIEmbeddingMode(mode)
self.model = OpenAIEmbeddingModelType(model)
self.deployment_name = deployment_name
self.openai_api_key = openai_api_key
self.openai_api_type = kwargs.get('openai_api_type')
self.openai_api_version = kwargs.get('openai_api_version')
self.openai_api_base = kwargs.get('openai_api_base')
@handle_llm_exceptions
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _QUERY_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _QUERY_MODE_MODEL_DICT[key]
return get_embedding(query, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return get_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings.
By default, this is a wrapper around _get_text_embedding.
Can be overriden for batch queries.
"""
if self.openai_api_type and self.openai_api_type == 'azure':
embeddings = []
for text in texts:
embeddings.append(self._get_text_embedding(text))
return embeddings
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings."""
if self.openai_api_type and self.openai_api_type == 'azure':
embeddings = []
for text in texts:
embeddings.append(await self._aget_text_embedding(text))
return embeddings
if self.deployment_name is not None:
engine = self.deployment_name
else:
key = (self.mode, self.model)
if key not in _TEXT_MODE_MODEL_DICT:
raise ValueError(f"Invalid mode, model combination: {key}")
engine = _TEXT_MODE_MODEL_DICT[key]
embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
api_type=self.openai_api_type, api_version=self.openai_api_version,
api_base=self.openai_api_base)
return embeddings

View File

@@ -1,7 +1,8 @@
import logging
from langchain import PromptTemplate
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException
from langchain.schema import HumanMessage, OutputParserException, BaseMessage
from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
@@ -10,7 +11,7 @@ from core.llm.token_calculator import TokenCalculator
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT
@@ -22,10 +23,10 @@ class LLMGenerator:
@classmethod
def generate_conversation_name(cls, tenant_id: str, query, answer):
prompt = CONVERSATION_TITLE_PROMPT
prompt = prompt.format(query=query, answer=answer)
prompt = prompt.format(query=query)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
model_name='gpt-3.5-turbo',
max_tokens=50
)
@@ -39,11 +40,12 @@ class LLMGenerator:
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
model = 'gpt-3.5-turbo'
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = TokenCalculator.get_num_tokens(generate_base_model, prompt_with_empty_context)
rest_tokens = llm_constant.max_context_token_length[generate_base_model] - prompt_tokens - max_tokens
prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context)
rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1
context = ''
for message in messages:
@@ -51,14 +53,17 @@ class LLMGenerator:
continue
message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
if rest_tokens - TokenCalculator.get_num_tokens(generate_base_model, context + message_qa_text) > 0:
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
context += message_qa_text
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
model_name=model,
max_tokens=max_tokens
)
@@ -91,8 +96,8 @@ class LLMGenerator:
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
prompt = OutLinePromptTemplate(
template="{histories}\n{format_instructions}\nquestions:\n",
prompt = JinjaPromptTemplate(
template="{{histories}}\n{{format_instructions}}\nquestions:\n",
input_variables=["histories"],
partial_variables={"format_instructions": format_instructions}
)
@@ -101,7 +106,7 @@ class LLMGenerator:
llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name=generate_base_model,
model_name='gpt-3.5-turbo',
temperature=0,
max_tokens=256
)
@@ -113,6 +118,8 @@ class LLMGenerator:
try:
output = llm(query)
if isinstance(output, BaseMessage):
output = output.content
questions = output_parser.parse(output)
except Exception:
logging.exception("Error generating suggested questions after answer")

59
api/core/index/base.py Normal file
View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from abc import abstractmethod, ABC
from typing import List, Any
from langchain.schema import Document, BaseRetriever
from models.dataset import Dataset
class BaseIndex(ABC):
def __init__(self, dataset: Dataset):
self.dataset = dataset
@abstractmethod
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError
@abstractmethod
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
raise NotImplementedError
@abstractmethod
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts:
doc_id = text.metadata['doc_id']
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]

41
api/core/index/index.py Normal file
View File

@@ -0,0 +1,41 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class IndexBuilder:
@classmethod
def get_index(cls, dataset: Dataset, indexing_technique: str, ignore_high_quality_check: bool = False):
if indexing_technique == "high_quality":
if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality':
return None
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
return VectorIndex(
dataset=dataset,
config=current_app.config,
embeddings=embeddings
)
elif indexing_technique == "economy":
return KeywordTableIndex(
dataset=dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=10
)
)
else:
raise ValueError('Unknown indexing technique')

View File

@@ -1,60 +0,0 @@
from langchain.callbacks import CallbackManager
from llama_index import ServiceContext, PromptHelper, LLMPredictor
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.embedding.openai_embedding import OpenAIEmbedding
from core.llm.llm_builder import LLMBuilder
class IndexBuilder:
@classmethod
def get_default_service_context(cls, tenant_id: str) -> ServiceContext:
# set number of output tokens
num_output = 512
# only for verbose
callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='text-davinci-003',
temperature=0,
max_tokens=num_output,
callback_manager=callback_manager,
)
llm_predictor = LLMPredictor(llm=llm)
# These parameters here will affect the logic of segmenting the final synthesized response.
# The number of refinement iterations in the synthesis process depends
# on whether the length of the segmented output exceeds the max_input_size.
prompt_helper = PromptHelper(
max_input_size=3500,
num_output=num_output,
max_chunk_overlap=20
)
provider = LLMBuilder.get_default_provider(tenant_id)
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=tenant_id,
model_provider=provider,
model_name='text-embedding-ada-002'
)
return ServiceContext.from_defaults(
llm_predictor=llm_predictor,
prompt_helper=prompt_helper,
embed_model=OpenAIEmbedding(**model_credentials),
)
@classmethod
def get_fake_llm_service_context(cls, tenant_id: str) -> ServiceContext:
llm = LLMBuilder.to_llm(
tenant_id=tenant_id,
model_name='fake'
)
return ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)

View File

@@ -1,159 +0,0 @@
import re
from typing import (
Any,
Dict,
List,
Set,
Optional
)
import jieba.analyse
from core.index.keyword_table.stopwords import STOPWORDS
from llama_index.indices.query.base import IS
from llama_index import QueryMode
from llama_index.indices.base import QueryMap
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
def jieba_extract_keywords(
text_chunk: str,
max_keywords: Optional[int] = None,
expand_with_subtokens: bool = True,
) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text_chunk,
topK=max_keywords,
)
if expand_with_subtokens:
return set(expand_tokens_with_subtokens(keywords))
else:
return set(keywords)
def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results
class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex):
"""GPT JIEBA Keyword Table Index.
This index uses a JIEBA keyword extractor to extract keywords from the text.
"""
def _extract_keywords(self, text: str) -> Set[str]:
"""Extract keywords from text."""
return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk)
@classmethod
def get_query_map(self) -> QueryMap:
"""Get query map."""
super_map = super().get_query_map()
super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery
return super_map
def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document."""
# get set of ids that correspond to node
node_idxs_to_delete = {doc_id}
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in self._index_struct.table.items():
if node_idxs_to_delete.intersection(node_idxs):
self._index_struct.table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not self._index_struct.table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del self._index_struct.table[keyword]
class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery):
"""GPT Keyword Table Index JIEBA Query.
Extracts keywords using JIEBA keyword extractor.
Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`.
.. code-block:: python
response = index.query("<query_str>", mode="jieba")
See BaseGPTKeywordTableQuery for arguments.
"""
@classmethod
def from_args(
cls,
index_struct: IS,
service_context: ServiceContext,
docstore: Optional[BaseDocumentStore] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
# response synthesizer args
response_mode: ResponseMode = ResponseMode.DEFAULT,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_kwargs: Optional[Dict] = None,
use_async: bool = False,
streaming: bool = False,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
# class-specific args
**kwargs: Any,
) -> "BaseGPTIndexQuery":
response_synthesizer = EnhanceResponseSynthesizer.from_args(
service_context=service_context,
text_qa_template=text_qa_template,
refine_template=refine_template,
simple_template=simple_template,
response_mode=response_mode,
response_kwargs=response_kwargs,
use_async=use_async,
streaming=streaming,
optimizer=optimizer,
)
return cls(
index_struct=index_struct,
service_context=service_context,
response_synthesizer=response_synthesizer,
docstore=docstore,
node_postprocessors=node_postprocessors,
verbose=verbose,
**kwargs,
)
def _get_keywords(self, query_str: str) -> List[str]:
"""Extract keywords."""
return list(
jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query)
)

View File

@@ -1,135 +0,0 @@
import json
from typing import List, Optional
from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding
from llama_index.data_structs import KeywordTable, Node
from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex
from llama_index.indices.registry import load_index_struct_from_dict
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.docstore.empty_docstore import EmptyDocumentStore
from core.index.index_builder import IndexBuilder
from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
from core.llm.llm_builder import LLMBuilder
from extensions.ext_database import db
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
class KeywordTableIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
index_struct = KeywordTable()
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node in nodes:
keywords = index._extract_keywords(node.get_text())
self.update_segment_keywords(node.doc_id, list(keywords))
index._index_struct.add_node(list(keywords), node)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
def del_nodes(self, node_ids: List[str]):
llm = LLMBuilder.to_llm(
tenant_id=self._dataset.tenant_id,
model_name='fake'
)
service_context = ServiceContext.from_defaults(
llm_predictor=LLMPredictor(llm=llm),
embed_model=OpenAIEmbedding()
)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return
else:
index_struct_dict = dataset_keyword_table.keyword_table_dict
index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict)
# create index
index = GPTJIEBAKeywordTableIndex(
index_struct=index_struct,
docstore=EmptyDocumentStore(),
service_context=service_context
)
for node_id in node_ids:
index.delete(node_id)
index_struct_dict = index.index_struct.to_dict()
if not dataset_keyword_table:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self._dataset.id,
keyword_table=json.dumps(index_struct_dict)
)
db.session.add(dataset_keyword_table)
else:
dataset_keyword_table.keyword_table = json.dumps(index_struct_dict)
db.session.commit()
@property
def query_index(self) -> Optional[BaseGPTKeywordTableIndex]:
docstore = DatesetDocumentStore(
dataset=self._dataset,
user_id=self._dataset.created_by,
embedding_model_name="text-embedding-ada-002"
)
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
dataset_keyword_table = self.get_keyword_table()
if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict:
return None
index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict)
return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context)
def get_keyword_table(self):
dataset_keyword_table = self._dataset.dataset_keyword_table
if dataset_keyword_table:
return dataset_keyword_table
return None
def update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()

View File

@@ -0,0 +1,33 @@
import re
from typing import Set
import jieba
from jieba.analyse import default_tfidf
from core.index.keyword_table_index.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
)
return set(self._expand_tokens_with_subtokens(keywords))
def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results

View File

@@ -0,0 +1,238 @@
import json
from collections import defaultdict
from typing import Any, List, Optional, Dict
from langchain.schema import Document, BaseRetriever
from pydantic import BaseModel, Field, Extra
from core.index.base import BaseIndex
from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetKeywordTable
class KeywordTableConfig(BaseModel):
max_keywords_per_chunk: int = 10
class KeywordTableIndex(BaseIndex):
def __init__(self, dataset: Dataset, config: KeywordTableConfig = KeywordTableConfig()):
super().__init__(dataset)
self._config = config
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
self._save_dataset_keyword_table(keyword_table)
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def delete_by_document_id(self, document_id: str):
# get segment ids by document_id
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.document_id == document_id
).all()
ids = [segment.id for segment in segments]
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
return KeywordTableRetriever(index=self, **kwargs)
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
keyword_table = self._get_dataset_keyword_table()
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
k = search_kwargs.get('k') if search_kwargs.get('k') else 4
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == self.dataset.id,
DocumentSegment.index_node_id == chunk_index
).first()
if segment:
documents.append(Document(
page_content=segment.content,
metadata={
"doc_id": chunk_index,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
))
return documents
def delete(self) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": keyword_table
}
}
self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
db.session.commit()
def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
if dataset_keyword_table.keyword_table_dict:
return dataset_keyword_table.keyword_table_dict['__data__']['table']
else:
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
return {}
def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict:
for keyword in keywords:
if keyword not in keyword_table:
keyword_table[keyword] = set()
keyword_table[keyword].add(id)
return keyword_table
def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> dict:
# get set of ids that correspond to node
node_idxs_to_delete = set(ids)
# delete node_idxs from keyword to node idxs mapping
keywords_to_delete = set()
for keyword, node_idxs in keyword_table.items():
if node_idxs_to_delete.intersection(node_idxs):
keyword_table[keyword] = node_idxs.difference(
node_idxs_to_delete
)
if not keyword_table[keyword]:
keywords_to_delete.add(keyword)
for keyword in keywords_to_delete:
del keyword_table[keyword]
return keyword_table
def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4):
keyword_table_handler = JiebaKeywordTableHandler()
keywords = keyword_table_handler.extract_keywords(query)
# go through text chunks in order of most matching keywords
chunk_indices_count: Dict[str, int] = defaultdict(int)
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords:
for node_id in keyword_table[keyword]:
chunk_indices_count[node_id] += 1
sorted_chunk_indices = sorted(
list(chunk_indices_count.keys()),
key=lambda x: chunk_indices_count[x],
reverse=True,
)
return sorted_chunk_indices[: k]
def _update_segment_keywords(self, node_id: str, keywords: List[str]):
document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first()
if document_segment:
document_segment.keywords = keywords
db.session.commit()
class KeywordTableRetriever(BaseRetriever, BaseModel):
index: KeywordTableIndex
search_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
return self.index.search(query, **self.search_kwargs)
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError("KeywordTableRetriever does not support async")
class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return super().default(obj)

View File

@@ -1,79 +0,0 @@
from typing import (
Any,
Dict,
Optional, Sequence,
)
from llama_index.indices.response.response_synthesis import ResponseSynthesizer
from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from llama_index.types import RESPONSE_TEXT_TYPE
class EnhanceResponseSynthesizer(ResponseSynthesizer):
@classmethod
def from_args(
cls,
service_context: ServiceContext,
streaming: bool = False,
use_async: bool = False,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_mode: ResponseMode = ResponseMode.DEFAULT,
response_kwargs: Optional[Dict] = None,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
) -> "ResponseSynthesizer":
response_builder: Optional[BaseResponseBuilder] = None
if response_mode != ResponseMode.NO_TEXT:
if response_mode == 'no_synthesizer':
response_builder = NoSynthesizer(
service_context=service_context,
simple_template=simple_template,
streaming=streaming,
)
else:
response_builder = get_response_builder(
service_context,
text_qa_template,
refine_template,
simple_template,
response_mode,
use_async=use_async,
streaming=streaming,
)
return cls(response_builder, response_mode, response_kwargs, optimizer)
class NoSynthesizer(BaseResponseBuilder):
def __init__(
self,
service_context: ServiceContext,
simple_template: Optional[SimpleInputPrompt] = None,
streaming: bool = False,
) -> None:
super().__init__(service_context, streaming)
async def aget_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)
def get_response(
self,
query_str: str,
text_chunks: Sequence[str],
prev_response: Optional[str] = None,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return "\n".join(text_chunks)

View File

@@ -1,22 +0,0 @@
from pathlib import Path
from typing import Dict
from bs4 import BeautifulSoup
from llama_index.readers.file.base_parser import BaseParser
class HTMLParser(BaseParser):
"""HTML parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
with open(file, "rb") as fp:
soup = BeautifulSoup(fp, 'html.parser')
text = soup.get_text()
text = text.strip() if text else ''
return text

View File

@@ -1,111 +0,0 @@
"""Markdown parser.
Contains parser for md files.
"""
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from llama_index.readers.file.base_parser import BaseParser
class MarkdownParser(BaseParser):
"""Markdown parser.
Extract text from markdown files.
Returns dictionary with keys as headers and values as the text between headers.
"""
def __init__(
self,
*args: Any,
remove_hyperlinks: bool = True,
remove_images: bool = True,
**kwargs: Any,
) -> None:
"""Init params."""
super().__init__(*args, **kwargs)
self._remove_hyperlinks = remove_hyperlinks
self._remove_images = remove_images
def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
"""Convert a markdown file to a dictionary.
The keys are the headers and the values are the text under each header.
"""
markdown_tups: List[Tuple[Optional[str], str]] = []
lines = markdown_text.split("\n")
current_header = None
current_text = ""
for line in lines:
header_match = re.match(r"^#+\s", line)
if header_match:
if current_header is not None:
markdown_tups.append((current_header, current_text))
current_header = line
current_text = ""
else:
current_text += line + "\n"
markdown_tups.append((current_header, current_text))
if current_header is not None:
# pass linting, assert keys are defined
markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]
else:
markdown_tups = [
(key, re.sub("\n", "", value)) for key, value in markdown_tups
]
return markdown_tups
def remove_images(self, content: str) -> str:
"""Get a dictionary of a markdown file from its path."""
pattern = r"!{1}\[\[(.*)\]\]"
content = re.sub(pattern, "", content)
return content
def remove_hyperlinks(self, content: str) -> str:
"""Get a dictionary of a markdown file from its path."""
pattern = r"\[(.*?)\]\((.*?)\)"
content = re.sub(pattern, r"\1", content)
return content
def _init_parser(self) -> Dict:
"""Initialize the parser with the config."""
return {}
def parse_tups(
self, filepath: Path, errors: str = "ignore"
) -> List[Tuple[Optional[str], str]]:
"""Parse file into tuples."""
with open(filepath, "r", encoding="utf-8") as f:
content = f.read()
if self._remove_hyperlinks:
content = self.remove_hyperlinks(content)
if self._remove_images:
content = self.remove_images(content)
markdown_tups = self.markdown_to_tups(content)
return markdown_tups
def parse_file(
self, filepath: Path, errors: str = "ignore"
) -> Union[str, List[str]]:
"""Parse file into string."""
tups = self.parse_tups(filepath, errors=errors)
results = []
# TODO: don't include headers right now
for header, value in tups:
if header is None:
results.append(value)
else:
results.append(f"\n\n{header}\n{value}")
return results

View File

@@ -1,56 +0,0 @@
from pathlib import Path
from typing import Dict
from flask import current_app
from llama_index.readers.file.base_parser import BaseParser
from pypdf import PdfReader
from extensions.ext_storage import storage
from models.model import UploadFile
class PDFParser(BaseParser):
"""PDF parser."""
def _init_parser(self) -> Dict:
"""Init parser."""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
if not current_app.config.get('PDF_PREVIEW', True):
return ''
plaintext_file_key = ''
plaintext_file_exists = False
if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']:
upload_file: UploadFile = self._parser_config['upload_file']
if upload_file.hash:
plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext'
try:
text = storage.load(plaintext_file_key).decode('utf-8')
plaintext_file_exists = True
return text
except FileNotFoundError:
pass
text_list = []
with open(file, "rb") as fp:
# Create a PDF object
pdf = PdfReader(fp)
# Get the number of pages in the PDF document
num_pages = len(pdf.pages)
# Iterate over every page
for page in range(num_pages):
# Extract the text from the page
page_text = pdf.pages[page].extract_text()
text_list.append(page_text)
text = "\n".join(text_list)
# save plaintext file for caching
if not plaintext_file_exists and plaintext_file_key:
storage.save(plaintext_file_key, text.encode('utf-8'))
return text

View File

@@ -1,31 +0,0 @@
from pathlib import Path
import json
from typing import Dict
from openpyxl import load_workbook
from llama_index.readers.file.base_parser import BaseParser
from flask import current_app
class XLSXParser(BaseParser):
"""XLSX parser."""
def _init_parser(self) -> Dict:
"""Init parser"""
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> str:
data = []
keys = []
with open(file, "r") as fp:
wb = load_workbook(filename=file, read_only=True)
# loop over all sheets
for sheet in wb:
for row in sheet.iter_rows(values_only=True):
if all(v is None for v in row):
continue
if keys == []:
keys = list(map(str, row))
else:
data.append(json.dumps(dict(zip(keys, list(map(str, row)))), ensure_ascii=False))
return '\n\n'.join(data)

View File

@@ -1,136 +0,0 @@
import json
import logging
from typing import List, Optional
from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type
from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding
class VectorIndex:
def __init__(self, dataset: Dataset):
self._dataset = dataset
def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
if not self._dataset.index_struct_dict:
index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
db.session.commit()
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
if duplicate_check:
nodes = self._filter_duplicate_nodes(index, nodes)
embedding_queue_nodes = []
embedded_nodes = []
for node in nodes:
node_hash = node.doc_hash
# if node hash in cached embedding tables, use cached embedding
embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
if embedding:
node.embedding = embedding.get_embedding()
embedded_nodes.append(node)
else:
embedding_queue_nodes.append(node)
if embedding_queue_nodes:
embedding_results = index._get_node_embedding_results(
embedding_queue_nodes,
set(),
)
# pre embed nodes for cached embedding
for embedding_result in embedding_results:
node = embedding_result.node
node.embedding = embedding_result.embedding
try:
embedding = Embedding(hash=node.doc_hash)
embedding.set_embedding(node.embedding)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
db.session.rollback()
continue
except:
logging.exception('Failed to add embedding to db')
continue
embedded_nodes.append(node)
self.index_insert_nodes(index, embedded_nodes)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
index.insert_nodes(nodes)
def del_nodes(self, node_ids: List[str]):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
for node_id in node_ids:
self.index_delete_node(index, node_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
index.delete_node(node_id)
def del_doc(self, doc_id: str):
if not self._dataset.index_struct_dict:
return
service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
index = vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
self.index_delete_doc(index, doc_id)
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
index.delete(doc_id)
@property
def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
if not self._dataset.index_struct_dict:
return None
service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
return vector_store.get_index(
service_context=service_context,
index_struct=self._dataset.index_struct_dict
)
def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
for node in nodes:
node_id = node.doc_id
exists_duplicate_node = index.exists_by_node_id(node_id)
if exists_duplicate_node:
nodes.remove(node)
return nodes

View File

@@ -0,0 +1,175 @@
import json
import logging
from abc import abstractmethod
from typing import List, Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from weaviate import UnexpectedStatusCodeException
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
class BaseVectorIndex(BaseIndex):
def __init__(self, dataset: Dataset, embeddings: Embeddings):
super().__init__(dataset)
self._embeddings = embeddings
self._vector_store = None
def get_type(self) -> str:
raise NotImplementedError
@abstractmethod
def get_index_name(self, dataset: Dataset) -> str:
raise NotImplementedError
@abstractmethod
def to_index_struct(self) -> dict:
raise NotImplementedError
@abstractmethod
def _get_vector_store(self) -> VectorStore:
raise NotImplementedError
@abstractmethod
def _get_vector_store_class(self) -> type:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
if search_type == 'similarity_score_threshold':
score_threshold = search_kwargs.get("score_threshold")
if (score_threshold is None) or (not isinstance(score_threshold, float)):
search_kwargs['score_threshold'] = .0
docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
query, **search_kwargs
)
docs = []
for doc, similarity in docs_with_similarity:
doc.metadata['score'] = similarity
docs.append(doc)
return docs
# similarity k
# mmr k, fetch_k, lambda_mult
# similarity_score_threshold k
return vector_store.as_retriever(
search_type=search_type,
search_kwargs=search_kwargs
).get_relevant_documents(query)
def get_retriever(self, **kwargs: Any) -> BaseRetriever:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.as_retriever(**kwargs)
def add_texts(self, texts: list[Document], **kwargs):
if self._is_origin():
self.recreate_dataset(self.dataset)
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if kwargs.get('duplicate_check', False):
texts = self._filter_duplicate_texts(texts)
uuids = self._get_uuids(texts)
vector_store.add_documents(texts, uuids=uuids)
def text_exists(self, id: str) -> bool:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
for node_id in ids:
vector_store.del_text(node_id)
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.delete()
def _is_origin(self):
return False
def recreate_dataset(self, dataset: Dataset):
logging.info(f"Recreating dataset {dataset.id}")
try:
self.delete()
except UnexpectedStatusCodeException as e:
if e.status_code != 400:
# 400 means index not exists
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
origin_index_struct = self.dataset.index_struct[:]
self.dataset.index_struct = None
if documents:
try:
self.create(documents)
except Exception as e:
self.dataset.index_struct = origin_index_struct
raise e
dataset.index_struct = json.dumps(self.to_index_struct())
db.session.commit()
self.dataset = dataset
logging.info(f"Dataset {dataset.id} recreate successfully.")

View File

@@ -0,0 +1,116 @@
import os
from typing import Optional, Any, List, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
}
class QdrantVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: QdrantConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client_config = config
def get_type(self) -> str:
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
return self.dataset.index_struct_dict['vector_store']['collection_name']
dataset_id = dataset.id
return "Index_" + dataset_id.replace("-", "_")
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"collection_name": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='text',
**self._client_config.to_qdrant_params()
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
return QdrantVectorStore(
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='text'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="metadata.document_id",
match=models.MatchValue(value=document_id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['collection_name']
if class_prefix.startswith('Vector_'):
# original class_prefix
return True
return False

View File

@@ -0,0 +1,69 @@
import json
from flask import current_app
from langchain.embeddings.base import Embeddings
from core.index.vector_index.base import BaseVectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document
class VectorIndex:
def __init__(self, dataset: Dataset, config: dict, embeddings: Embeddings):
self._dataset = dataset
self._embeddings = embeddings
self._vector_index = self._init_vector_index(dataset, config, embeddings)
def _init_vector_index(self, dataset: Dataset, config: dict, embeddings: Embeddings) -> BaseVectorIndex:
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError(f"Vector store must be specified.")
if vector_type == "weaviate":
from core.index.vector_index.weaviate_vector_index import WeaviateVectorIndex, WeaviateConfig
return WeaviateVectorIndex(
dataset=dataset,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
embeddings=embeddings
)
elif vector_type == "qdrant":
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
return QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def add_texts(self, texts: list[Document], **kwargs):
if not self._dataset.index_struct_dict:
self._vector_index.create(texts, **kwargs)
self._dataset.index_struct = json.dumps(self._vector_index.to_index_struct())
db.session.commit()
return
self._vector_index.add_texts(texts, **kwargs)
def __getattr__(self, name):
if self._vector_index is not None:
method = getattr(self._vector_index, name)
if callable(method):
return method
raise AttributeError(f"'VectorIndex' object has no attribute '{name}'")

View File

@@ -0,0 +1,136 @@
from typing import Optional, cast
import requests
import weaviate
from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel, root_validator
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: WeaviateConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self.get_index_name(self.dataset)}
}
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
)
def _get_vector_store_class(self) -> type:
return WeaviateVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
return False

View File

@@ -1,35 +1,34 @@
import datetime
import json
import logging
import re
import tempfile
import time
from pathlib import Path
from typing import Optional, List
import uuid
from typing import Optional, List, cast
from flask import current_app
from flask_login import current_user
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from llama_index import SimpleDirectoryReader
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.node_parser import SimpleNodeParser, NodeParser
from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR
from llama_index.readers.file.markdown_parser import MarkdownParser
from core.data_source.notion import NotionPageReader
from core.index.readers.xlsx_parser import XLSXParser
from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatesetDocumentStore
from core.index.keyword_table_index import KeywordTableIndex
from core.index.readers.html_parser import HTMLParser
from core.index.readers.markdown_parser import MarkdownParser
from core.index.readers.pdf_parser import PDFParser
from core.index.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.index.vector_index import VectorIndex
from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.error import ProviderTokenNotInitError
from core.llm.llm_builder import LLMBuilder
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.llm.token_calculator import TokenCalculator
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule
from libs import helper
from models.dataset import Document as DatasetDocument
from models.dataset import Dataset, DocumentSegment, DatasetProcessRule
from models.model import UploadFile
from models.source import DataSourceBinding
@@ -40,135 +39,171 @@ class IndexingRunner:
self.storage = storage
self.embedding_model_name = embedding_model_name
def run(self, documents: List[Document]):
def run(self, dataset_documents: List[DatasetDocument]):
"""Run the indexing process."""
for document in documents:
for dataset_document in dataset_documents:
try:
# get dataset
dataset = Dataset.query.filter_by(
id=dataset_document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# load file
text_docs = self._load_data(dataset_document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# get splitter
splitter = self._get_splitter(processing_rule)
# split to documents
documents = self._step_split(
text_docs=text_docs,
splitter=splitter,
dataset=dataset,
dataset_document=dataset_document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
dataset_document=dataset_document,
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
try:
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
id=dataset_document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=dataset_document.id
).all()
db.session.delete(document_segments)
db.session.commit()
# load file
text_docs = self._load_data(document)
text_docs = self._load_data(dataset_document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# get splitter
splitter = self._get_splitter(processing_rule)
# split to nodes
nodes = self._step_split(
# split to documents
documents = self._step_split(
text_docs=text_docs,
node_parser=node_parser,
splitter=splitter,
dataset=dataset,
document=document,
dataset_document=dataset_document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
dataset_document=dataset_document,
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def run_in_splitting_status(self, document: Document):
"""Run the indexing process when the index_status is splitting."""
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=document.id
).all()
db.session.delete(document_segments)
db.session.commit()
# load file
text_docs = self._load_data(document)
# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == document.dataset_process_rule_id). \
first()
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# split to nodes
nodes = self._step_split(
text_docs=text_docs,
node_parser=node_parser,
dataset=dataset,
document=document,
processing_rule=processing_rule
)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
)
def run_in_indexing_status(self, document: Document):
def run_in_indexing_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is indexing."""
# get dataset
dataset = Dataset.query.filter_by(
id=document.dataset_id
).first()
try:
# get dataset
dataset = Dataset.query.filter_by(
id=dataset_document.dataset_id
).first()
if not dataset:
raise ValueError("no dataset found")
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=document.id
).all()
nodes = []
if document_segments:
for document_segment in document_segments:
# transform segment to node
if document_segment.status != "completed":
relationships = {
DocumentRelationship.SOURCE: document_segment.document_id,
}
# get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by(
dataset_id=dataset.id,
document_id=dataset_document.id
).all()
previous_segment = document_segment.previous_segment
if previous_segment:
relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
documents = []
if document_segments:
for document_segment in document_segments:
# transform segment to node
if document_segment.status != "completed":
document = Document(
page_content=document_segment.content,
metadata={
"doc_id": document_segment.index_node_id,
"doc_hash": document_segment.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
}
)
next_segment = document_segment.next_segment
if next_segment:
relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
node = Node(
doc_id=document_segment.index_node_id,
doc_hash=document_segment.index_node_hash,
text=document_segment.content,
extra_info=None,
node_info=None,
relationships=relationships
)
nodes.append(node)
documents.append(document)
# build index
self._build_index(
dataset=dataset,
document=document,
nodes=nodes
)
# build index
self._build_index(
dataset=dataset,
dataset_document=dataset_document,
documents=documents
)
except DocumentIsPausedException:
raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
except ProviderTokenNotInitError as e:
dataset_document.indexing_status = 'error'
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict) -> dict:
"""
@@ -179,28 +214,29 @@ class IndexingRunner:
total_segments = 0
for file_detail in file_details:
# load data from file
text_docs = self._load_data_from_file(file_detail)
text_docs = FileExtractor.load(file_detail)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# get splitter
splitter = self._get_splitter(processing_rule)
# split to nodes
nodes = self._split_to_nodes(
# split to documents
documents = self._split_to_documents(
text_docs=text_docs,
node_parser=node_parser,
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(nodes)
for node in nodes:
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(node.get_text())
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name,
self.filter_string(document.page_content))
return {
"total_segments": total_segments,
@@ -230,35 +266,36 @@ class IndexingRunner:
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
reader = NotionPageReader(integration_token=data_source_binding.access_token)
for page in notion_info['pages']:
if page['type'] == 'page':
page_ids = [page['page_id']]
documents = reader.load_data_as_documents(page_ids=page_ids)
elif page['type'] == 'database':
documents = reader.load_data_as_documents(database_id=page['page_id'])
else:
documents = []
loader = NotionLoader(
notion_access_token=data_source_binding.access_token,
notion_workspace_id=workspace_id,
notion_obj_id=page['page_id'],
notion_page_type=page['type']
)
documents = loader.load()
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)
# get node parser for splitting
node_parser = self._get_node_parser(processing_rule)
# get splitter
splitter = self._get_splitter(processing_rule)
# split to nodes
nodes = self._split_to_nodes(
# split to documents
documents = self._split_to_documents(
text_docs=documents,
node_parser=node_parser,
splitter=splitter,
processing_rule=processing_rule
)
total_segments += len(nodes)
for node in nodes:
total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(node.get_text())
preview_texts.append(document.page_content)
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text())
tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
return {
"total_segments": total_segments,
@@ -268,14 +305,14 @@ class IndexingRunner:
"preview": preview_texts
}
def _load_data(self, document: Document) -> List[Document]:
def _load_data(self, dataset_document: DatasetDocument) -> List[Document]:
# load file
if document.data_source_type not in ["upload_file", "notion_import"]:
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
return []
data_source_info = document.data_source_info_dict
data_source_info = dataset_document.data_source_info_dict
text_docs = []
if document.data_source_type == 'upload_file':
if dataset_document.data_source_type == 'upload_file':
if not data_source_info or 'upload_file_id' not in data_source_info:
raise ValueError("no upload file found")
@@ -283,109 +320,38 @@ class IndexingRunner:
filter(UploadFile.id == data_source_info['upload_file_id']). \
one_or_none()
text_docs = self._load_data_from_file(file_detail)
elif document.data_source_type == 'notion_import':
if not data_source_info or 'notion_page_id' not in data_source_info \
or 'notion_workspace_id' not in data_source_info:
raise ValueError("no notion page found")
workspace_id = data_source_info['notion_workspace_id']
page_id = data_source_info['notion_page_id']
page_type = data_source_info['type']
data_source_binding = DataSourceBinding.query.filter(
db.and_(
DataSourceBinding.tenant_id == document.tenant_id,
DataSourceBinding.provider == 'notion',
DataSourceBinding.disabled == False,
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
)
).first()
if not data_source_binding:
raise ValueError('Data source binding not found.')
if page_type == 'page':
# add page last_edited_time to data_source_info
self._get_notion_page_last_edited_time(page_id, data_source_binding.access_token, document)
text_docs = self._load_page_data_from_notion(page_id, data_source_binding.access_token)
elif page_type == 'database':
# add page last_edited_time to data_source_info
self._get_notion_database_last_edited_time(page_id, data_source_binding.access_token, document)
text_docs = self._load_database_data_from_notion(page_id, data_source_binding.access_token)
text_docs = FileExtractor.load(file_detail)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
# update document status to splitting
self._update_document_index_status(
document_id=document.id,
document_id=dataset_document.id,
after_indexing_status="splitting",
extra_update_params={
Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]),
Document.parsing_completed_at: datetime.datetime.utcnow()
DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
}
)
# replace doc id to document model id
text_docs = cast(List[Document], text_docs)
for text_doc in text_docs:
# remove invalid symbol
text_doc.text = self.filter_string(text_doc.get_text())
text_doc.doc_id = document.id
text_doc.page_content = self.filter_string(text_doc.page_content)
text_doc.metadata['document_id'] = dataset_document.id
text_doc.metadata['dataset_id'] = dataset_document.dataset_id
return text_docs
def filter_string(self, text):
pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]')
return pattern.sub('', text)
text = re.sub(r'<\|', '<', text)
text = re.sub(r'\|>', '>', text)
text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]', '', text)
return text
def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]:
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
self.storage.download(upload_file.key, filepath)
file_extractor = DEFAULT_FILE_EXTRACTOR.copy()
file_extractor[".markdown"] = MarkdownParser()
file_extractor[".md"] = MarkdownParser()
file_extractor[".html"] = HTMLParser()
file_extractor[".htm"] = HTMLParser()
file_extractor[".pdf"] = PDFParser({'upload_file': upload_file})
file_extractor[".xlsx"] = XLSXParser()
loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor)
text_docs = loader.load_data()
return text_docs
def _load_page_data_from_notion(self, page_id: str, access_token: str) -> List[Document]:
page_ids = [page_id]
reader = NotionPageReader(integration_token=access_token)
text_docs = reader.load_data_as_documents(page_ids=page_ids)
return text_docs
def _load_database_data_from_notion(self, database_id: str, access_token: str) -> List[Document]:
reader = NotionPageReader(integration_token=access_token)
text_docs = reader.load_data_as_documents(database_id=database_id)
return text_docs
def _get_notion_page_last_edited_time(self, page_id: str, access_token: str, document: Document):
reader = NotionPageReader(integration_token=access_token)
last_edited_time = reader.get_page_last_edited_time(page_id)
data_source_info = document.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
Document.data_source_info: json.dumps(data_source_info)
}
Document.query.filter_by(id=document.id).update(update_params)
db.session.commit()
def _get_notion_database_last_edited_time(self, page_id: str, access_token: str, document: Document):
reader = NotionPageReader(integration_token=access_token)
last_edited_time = reader.get_database_last_edited_time(page_id)
data_source_info = document.data_source_info_dict
data_source_info['last_edited_time'] = last_edited_time
update_params = {
Document.data_source_info: json.dumps(data_source_info)
}
Document.query.filter_by(id=document.id).update(update_params)
db.session.commit()
def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser:
def _get_splitter(self, processing_rule: DatasetProcessRule) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
@@ -414,68 +380,83 @@ class IndexingRunner:
separators=["\n\n", "", ".", " ", ""]
)
return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True)
return character_splitter
def _step_split(self, text_docs: List[Document], node_parser: NodeParser,
dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]:
def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
-> List[Document]:
"""
Split the text documents into nodes and save them to the document segment.
Split the text documents into documents and save them to the document segment.
"""
nodes = self._split_to_nodes(
documents = self._split_to_documents(
text_docs=text_docs,
node_parser=node_parser,
splitter=splitter,
processing_rule=processing_rule
)
# save node to document segment
doc_store = DatesetDocumentStore(
dataset=dataset,
user_id=document.created_by,
user_id=dataset_document.created_by,
embedding_model_name=self.embedding_model_name,
document_id=document.id
document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(nodes)
doc_store.add_documents(documents)
# update document status to indexing
cur_time = datetime.datetime.utcnow()
self._update_document_index_status(
document_id=document.id,
document_id=dataset_document.id,
after_indexing_status="indexing",
extra_update_params={
Document.cleaning_completed_at: cur_time,
Document.splitting_completed_at: cur_time,
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
}
)
# update segment status to indexing
self._update_segments_by_document(
document_id=document.id,
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.utcnow()
}
)
return nodes
return documents
def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser,
processing_rule: DatasetProcessRule) -> List[Node]:
def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
"""
Split the text documents into nodes.
"""
all_nodes = []
all_documents = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.get_text(), processing_rule)
text_doc.text = document_text
document_text = self._document_clean(text_doc.page_content, processing_rule)
text_doc.page_content = document_text
# parse document to nodes
nodes = node_parser.get_nodes_from_documents([text_doc])
nodes = [node for node in nodes if node.text is not None and node.text.strip()]
all_nodes.extend(nodes)
documents = splitter.split_documents([text_doc])
return all_nodes
split_documents = []
for document in documents:
if document.page_content is None or not document.page_content.strip():
continue
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata['doc_id'] = doc_id
document.metadata['doc_hash'] = hash
split_documents.append(document)
all_documents.extend(split_documents)
return all_documents
def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
"""
@@ -506,37 +487,38 @@ class IndexingRunner:
return text
def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None:
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
"""
Build the index for the document.
"""
vector_index = VectorIndex(dataset=dataset)
keyword_table_index = KeywordTableIndex(dataset=dataset)
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
chunk_size = 100
for i in range(0, len(nodes), chunk_size):
for i in range(0, len(documents), chunk_size):
# check document is paused
self._check_document_paused_status(document.id)
chunk_nodes = nodes[i:i + chunk_size]
self._check_document_paused_status(dataset_document.id)
chunk_documents = documents[i:i + chunk_size]
tokens += sum(
TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes
TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content)
for document in chunk_documents
)
# save vector index
if dataset.indexing_technique == "high_quality":
vector_index.add_nodes(chunk_nodes)
if vector_index:
vector_index.add_texts(chunk_documents)
# save keyword index
keyword_table_index.add_nodes(chunk_nodes)
keyword_table_index.add_texts(chunk_documents)
node_ids = [node.doc_id for node in chunk_nodes]
document_ids = [document.metadata['doc_id'] for document in chunk_documents]
db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document.id,
DocumentSegment.index_node_id.in_(node_ids),
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == "indexing"
).update({
DocumentSegment.status: "completed",
@@ -549,12 +531,12 @@ class IndexingRunner:
# update document status to completed
self._update_document_index_status(
document_id=document.id,
document_id=dataset_document.id,
after_indexing_status="completed",
extra_update_params={
Document.tokens: tokens,
Document.completed_at: datetime.datetime.utcnow(),
Document.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.utcnow(),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
}
)
@@ -569,25 +551,25 @@ class IndexingRunner:
"""
Update the document indexing status.
"""
count = Document.query.filter_by(id=document_id, is_paused=True).count()
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
update_params = {
Document.indexing_status: after_indexing_status
DatasetDocument.indexing_status: after_indexing_status
}
if extra_update_params:
update_params.update(extra_update_params)
Document.query.filter_by(id=document_id).update(update_params)
DatasetDocument.query.filter_by(id=document_id).update(update_params)
db.session.commit()
def _update_segments_by_document(self, document_id: str, update_params: dict) -> None:
def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
"""
Update the document segment by document id.
"""
DocumentSegment.query.filter_by(document_id=document_id).update(update_params)
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()

View File

@@ -17,14 +17,16 @@ def handle_llm_exceptions(func):
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(str(e))
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(str(e))
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
except openai.error.OpenAIError as e:
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper
@@ -39,13 +41,15 @@ def handle_llm_exceptions_async(func):
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(str(e))
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(str(e))
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
except openai.error.OpenAIError as e:
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper

View File

@@ -1,7 +1,6 @@
from typing import Union, Optional
from typing import Union, Optional, List
from langchain.callbacks import CallbackManager
from langchain.llms.fake import FakeListLLM
from langchain.callbacks.base import BaseCallbackHandler
from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
@@ -32,12 +31,11 @@ class LLMBuilder:
"""
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
if model_name == 'fake':
return FakeListLLM(responses=[])
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
if provider == 'openai':
@@ -52,16 +50,21 @@ class LLMBuilder:
else:
raise ValueError(f"model name {model_name} is not supported.")
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
model_kwargs = {
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
}
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
return llm_cls(
model_name=model_name,
temperature=kwargs.get('temperature', 0),
max_tokens=kwargs.get('max_tokens', 256),
top_p=kwargs.get('top_p', 1),
frequency_penalty=kwargs.get('frequency_penalty', 0),
presence_penalty=kwargs.get('presence_penalty', 0),
callback_manager=kwargs.get('callback_manager', None),
**model_extras_kwargs,
callbacks=kwargs.get('callbacks', None),
streaming=kwargs.get('streaming', False),
# request_timeout=None
**model_credentials
@@ -69,7 +72,7 @@ class LLMBuilder:
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name")
completion_params = model.get("completion_params", {})
@@ -82,7 +85,7 @@ class LLMBuilder:
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming,
callback_manager=callback_manager
callbacks=callbacks
)
@classmethod

View File

@@ -42,7 +42,11 @@ class AzureProvider(BaseProvider):
"""
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure'
config['deployment_name'] = model_id.replace('.', '') if model_id else None
if model_id == 'text-embedding-ada-002':
config['deployment'] = model_id.replace('.', '') if model_id else None
config['chunk_size'] = 1
else:
config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config
def get_provider_name(self):

View File

@@ -1,3 +1,4 @@
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any
@@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
return message_tokens
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(result, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
@handle_llm_exceptions
def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop)
return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(messages, stop)
return await super().agenerate(messages, stop, callbacks, **kwargs)

View File

@@ -1,5 +1,4 @@
import os
from langchain.callbacks.manager import Callbacks
from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any
@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):
@handle_llm_exceptions
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(prompts, stop)
return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop)
return await super().agenerate(prompts, stop, callbacks, **kwargs)

View File

@@ -1,6 +1,7 @@
import os
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List, Dict, Any
@@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI):
return message_tokens
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(result, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
@handle_llm_exceptions
def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop)
return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(messages, stop)
return await super().agenerate(messages, stop, callbacks, **kwargs)

View File

@@ -1,5 +1,6 @@
import os
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI
@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_llm_exceptions
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(prompts, stop)
return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop)
return await super().agenerate(prompts, stop, callbacks, **kwargs)

25
api/core/llm/whisper.py Normal file
View File

@@ -0,0 +1,25 @@
import openai
from models.provider import ProviderName
from core.llm.error_handle_wraps import handle_llm_exceptions
from core.llm.provider.base import BaseProvider
class Whisper:
def __init__(self, provider: BaseProvider):
self.provider = provider
if self.provider.get_provider_name() == ProviderName.OPENAI:
self.client = openai.Audio
self.credentials = provider.get_credentials()
@handle_llm_exceptions
def transcribe(self, file):
return self.client.transcribe(
model='whisper-1',
file=file,
api_key=self.credentials.get('openai_api_key'),
api_base=self.credentials.get('openai_api_base'),
api_type=self.credentials.get('openai_api_type'),
api_version=self.credentials.get('openai_api_version'),
)

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel
from langchain.schema import get_buffer_string
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory

View File

@@ -3,13 +3,13 @@ import re
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
from langchain.schema import BaseMessage
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompt_template import JinjaPromptTemplate
class PromptBuilder:
@classmethod
def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = OutLinePromptTemplate.from_template(prompt_content)
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
system_message = system_prompt_template.format(**prompt_inputs)
@@ -17,7 +17,7 @@ class PromptBuilder:
@classmethod
def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = OutLinePromptTemplate.from_template(prompt_content)
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
ai_message = ai_prompt_template.format(**prompt_inputs)
@@ -25,13 +25,14 @@ class PromptBuilder:
@classmethod
def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = OutLinePromptTemplate.from_template(prompt_content)
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
human_message = human_prompt_template.format(**inputs)
return human_message
@classmethod
def process_template(cls, template: str):
processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
# processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
# processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
return processed_template

View File

@@ -1,10 +1,34 @@
import re
from typing import Any
from jinja2 import Environment, meta
from langchain import PromptTemplate
from langchain.formatting import StrictFormatter
class JinjaPromptTemplate(PromptTemplate):
template_format: str = "jinja2"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
env = Environment()
template = template.replace("{{}}", "{}")
ast = env.parse(template)
input_variables = meta.find_undeclared_variables(ast)
if "partial_variables" in kwargs:
partial_variables = kwargs["partial_variables"]
input_variables = {
var for var in input_variables if var not in partial_variables
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
class OutLinePromptTemplate(PromptTemplate):
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
@@ -16,6 +40,24 @@ class OutLinePromptTemplate(PromptTemplate):
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
return OneLineFormatter().format(self.template, **kwargs)
class OneLineFormatter(StrictFormatter):
def parse(self, format_string):

View File

@@ -1,17 +1,15 @@
from llama_index import QueryKeywordExtractPrompt
CONVERSATION_TITLE_PROMPT = (
"Human:{query}\n-----\n"
"Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
"If the human said is conducted in Chinese, you should return a Chinese title.\n"
"If the human said is conducted in English, you should return an English title.\n"
"If what the human said is conducted in English, you should only return an English title.\n"
"If what the human said is conducted in Chinese, you should only return a Chinese title.\n"
"title:"
)
CONVERSATION_SUMMARY_PROMPT = (
"Please generate a short summary of the following conversation.\n"
"If the conversation communicating in Chinese, you should return a Chinese summary.\n"
"If the conversation communicating in English, you should return an English summary.\n"
"If the following conversation communicating in English, you should only return an English summary.\n"
"If the following conversation communicating in Chinese, you should only return a Chinese summary.\n"
"[Conversation Start]\n"
"{context}\n"
"[Conversation End]\n\n"
@@ -45,23 +43,6 @@ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"[\"question1\",\"question2\",\"question3\"]\n"
)
QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = (
"A question is provided below. Given the question, extract up to {max_keywords} "
"keywords from the text. Focus on extracting the keywords that we can use "
"to best lookup answers to the question. Avoid stopwords."
"I am not sure which language the following question is in. "
"If the user asked the question in Chinese, please return the keywords in Chinese. "
"If the user asked the question in English, please return the keywords in English.\n"
"---------------------\n"
"{question}\n"
"---------------------\n"
"Provide keywords in the following comma-separated format: 'KEYWORDS: <keywords>'\n"
)
QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt(
QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL
)
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
the model prompt that best suits the input.
You will be provided with the prompt, variables, and an opening statement.

View File

@@ -0,0 +1,87 @@
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from langchain.tools import BaseTool
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.llm.llm_builder import LLMBuilder
from models.dataset import Dataset
class DatasetTool(BaseTool):
"""Tool for querying a Dataset."""
dataset: Dataset
k: int = 2
def _run(self, tool_input: str) -> str:
if self.dataset.indexing_technique == "economy":
# use keyword table query
kw_table_index = KeywordTableIndex(
dataset=self.dataset,
config=KeywordTableConfig(
max_keywords_per_chunk=5
)
)
documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k})
else:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = vector_index.search(
tool_input,
search_type='similarity',
search_kwargs={
'k': self.k
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))
async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
model_name='text-embedding-ada-002'
)
embeddings = CacheEmbedding(OpenAIEmbeddings(
**model_credentials
))
vector_index = VectorIndex(
dataset=self.dataset,
config=current_app.config,
embeddings=embeddings
)
documents = await vector_index.asearch(
tool_input,
search_type='similarity',
search_kwargs={
'k': 10
}
)
hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id)
hit_callback.on_tool_end(documents)
return str("\n".join([document.page_content for document in documents]))

View File

@@ -1,73 +0,0 @@
from typing import Optional
from langchain.callbacks import CallbackManager
from llama_index.langchain_helpers.agents import IndexToolConfig
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset
class DatasetToolBuilder:
@classmethod
def build_dataset_tool(cls, dataset: Dataset,
response_mode: str = "no_synthesizer",
callback_handler: Optional[DatasetToolCallbackHandler] = None):
if dataset.indexing_technique == "economy":
# use keyword table query
index = KeywordTableIndex(dataset=dataset).query_index
if not index:
return None
query_kwargs = {
"mode": "default",
"response_mode": response_mode,
"query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE,
"max_keywords_per_query": 5,
# If num_chunks_per_query is too large,
# it will slow down the synthesis process due to multiple iterations of refinement.
"num_chunks_per_query": 2
}
else:
index = VectorIndex(dataset=dataset).query_index
if not index:
return None
query_kwargs = {
"mode": "default",
"response_mode": response_mode,
# If top_k is too large,
# it will slow down the synthesis process due to multiple iterations of refinement.
"similarity_top_k": 2
}
# fulfill description when it is empty
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
index_tool_config = IndexToolConfig(
index=index,
name=f"dataset-{dataset.id}",
description=description,
index_query_kwargs=query_kwargs,
tool_kwargs={
"callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()])
},
# tool_kwargs={"return_direct": True},
# return_direct: Whether to return LLM results directly or process the output data with an Output Parser
)
index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset.id)
return EnhanceLlamaIndexTool.from_tool_config(
tool_config=index_tool_config,
callback_handler=index_callback_handler
)

View File

@@ -1,43 +0,0 @@
from typing import Dict
from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field
from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
class EnhanceLlamaIndexTool(BaseTool):
"""Tool for querying a LlamaIndex."""
# NOTE: name/description still needs to be set
index: BaseGPTIndex
query_kwargs: Dict = Field(default_factory=dict)
return_sources: bool = False
callback_handler: IndexToolCallbackHandler
@classmethod
def from_tool_config(cls, tool_config: IndexToolConfig,
callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
"""Create a tool from a tool config."""
return_sources = tool_config.tool_kwargs.pop("return_sources", False)
return cls(
index=tool_config.index,
callback_handler=callback_handler,
name=tool_config.name,
description=tool_config.description,
return_sources=return_sources,
query_kwargs=tool_config.index_query_kwargs,
**tool_config.tool_kwargs,
)
def _run(self, tool_input: str) -> str:
response = self.index.query(tool_input, **self.query_kwargs)
self.callback_handler.on_tool_end(response)
return str(response)
async def _arun(self, tool_input: str) -> str:
response = await self.index.aquery(tool_input, **self.query_kwargs)
self.callback_handler.on_tool_end(response)
return str(response)

View File

@@ -1,34 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs import Node
from llama_index.vector_stores.types import VectorStore
class BaseVectorStoreClient(ABC):
@abstractmethod
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
raise NotImplementedError
@abstractmethod
def to_index_config(self, index_id: str) -> dict:
raise NotImplementedError
class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
def delete_node(self, node_id: str):
self._vector_store.delete_node(node_id)
def exists_by_node_id(self, node_id: str) -> bool:
return self._vector_store.exists_by_node_id(node_id)
class EnhanceVectorStore(ABC):
@abstractmethod
def delete_node(self, node_id: str):
pass
@abstractmethod
def exists_by_node_id(self, node_id: str) -> bool:
pass

View File

@@ -0,0 +1,69 @@
from typing import cast, Any
from langchain.schema import Document
from langchain.vectorstores import Qdrant
from qdrant_client.http.models import Filter, PointIdsList, FilterSelector
from qdrant_client.local.qdrant_local import QdrantLocal
class QdrantVectorStore(Qdrant):
def del_texts(self, filter: Filter):
if not filter:
raise ValueError('filter must not be empty')
self._reload_if_needed()
self.client.delete(
collection_name=self.collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def del_text(self, uuid: str) -> None:
self._reload_if_needed()
self.client.delete(
collection_name=self.collection_name,
points_selector=PointIdsList(
points=[uuid],
),
)
def text_exists(self, uuid: str) -> bool:
self._reload_if_needed()
response = self.client.retrieve(
collection_name=self.collection_name,
ids=[uuid]
)
return len(response) > 0
def delete(self):
self._reload_if_needed()
self.client.delete_collection(collection_name=self.collection_name)
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
if scored_point.payload.get('doc_id'):
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata={'doc_id': scored_point.id}
)
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
def _reload_if_needed(self):
if isinstance(self.client, QdrantLocal):
self.client = cast(QdrantLocal, self.client)
self.client._load()

View File

@@ -1,147 +0,0 @@
import os
from typing import cast, List
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from qdrant_client.http.models import Payload, Filter
import qdrant_client
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
from llama_index.vector_stores import QdrantVectorStore
from qdrant_client.local.qdrant_local import QdrantLocal
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
class QdrantVectorStoreClient(BaseVectorStoreClient):
def __init__(self, url: str, api_key: str, root_path: str):
self._client = self.init_from_config(url, api_key, root_path)
@classmethod
def init_from_config(cls, url: str, api_key: str, root_path: str):
if url and url.startswith('path:'):
path = url.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(root_path, path)
return qdrant_client.QdrantClient(
path=path
)
else:
return qdrant_client.QdrantClient(
url=url,
api_key=api_key,
)
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
index_struct = QdrantIndexDict()
if self._client is None:
raise Exception("Vector client is not initialized.")
# {"collection_name": "Gpt_index_xxx"}
collection_name = config.get('collection_name')
if not collection_name:
raise Exception("collection_name cannot be None.")
return GPTQdrantEnhanceIndex(
service_context=service_context,
index_struct=index_struct,
vector_store=QdrantEnhanceVectorStore(
client=self._client,
collection_name=collection_name
)
)
def to_index_config(self, index_id: str) -> dict:
return {"collection_name": index_id}
class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
pass
class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
def delete_node(self, node_id: str):
"""
Delete node from the index.
:param node_id: node id
"""
from qdrant_client.http import models as rest
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=rest.Filter(
must=[
rest.FieldCondition(
key="id", match=rest.MatchValue(value=node_id)
)
]
),
)
def exists_by_node_id(self, node_id: str) -> bool:
"""
Get node from the index by node id.
:param node_id: node id
"""
self._reload_if_needed()
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[node_id]
)
return len(response) > 0
def query(
self,
query: VectorStoreQuery,
) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
Args:
query (VectorStoreQuery): query
"""
query_embedding = cast(List[float], query.query_embedding)
self._reload_if_needed()
response = self._client.search(
collection_name=self._collection_name,
query_vector=query_embedding,
limit=cast(int, query.similarity_top_k),
query_filter=cast(Filter, self._build_query_filter(query)),
with_vectors=True
)
nodes = []
similarities = []
ids = []
for point in response:
payload = cast(Payload, point.payload)
node = Node(
doc_id=str(point.id),
text=payload.get("text"),
embedding=point.vector,
extra_info=payload.get("extra_info"),
relationships={
DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
},
)
nodes.append(node)
similarities.append(point.score)
ids.append(str(point.id))
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
def _reload_if_needed(self):
if isinstance(self._client._client, QdrantLocal):
self._client._client._load()

View File

@@ -1,62 +0,0 @@
from flask import Flask
from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']
class VectorStore:
def __init__(self):
self._vector_store = None
self._client = None
def init_app(self, app: Flask):
if not app.config['VECTOR_STORE']:
return
self._vector_store = app.config['VECTOR_STORE']
if self._vector_store not in SUPPORTED_VECTOR_STORES:
raise ValueError(f"Vector store {self._vector_store} is not supported.")
if self._vector_store == 'weaviate':
self._client = WeaviateVectorStoreClient(
endpoint=app.config['WEAVIATE_ENDPOINT'],
api_key=app.config['WEAVIATE_API_KEY'],
grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'],
batch_size=app.config['WEAVIATE_BATCH_SIZE']
)
elif self._vector_store == 'qdrant':
self._client = QdrantVectorStoreClient(
url=app.config['QDRANT_URL'],
api_key=app.config['QDRANT_API_KEY'],
root_path=app.root_path
)
app.extensions['vector_store'] = self
@retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
vector_store_config: dict = index_struct.get('vector_store')
index = self.get_client().get_index(
service_context=service_context,
config=vector_store_config
)
return index
def to_index_struct(self, index_id: str) -> dict:
return {
"type": self._vector_store,
"vector_store": self.get_client().to_index_config(index_id)
}
def get_client(self):
if not self._client:
raise Exception("Vector store client is not initialized.")
return self._client

View File

@@ -1,66 +0,0 @@
from llama_index.indices.query.base import IS
from typing import (
Any,
Dict,
List,
Optional
)
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
@classmethod
def from_args(
cls,
index_struct: IS,
service_context: ServiceContext,
docstore: Optional[BaseDocumentStore] = None,
node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
verbose: bool = False,
# response synthesizer args
response_mode: ResponseMode = ResponseMode.DEFAULT,
text_qa_template: Optional[QuestionAnswerPrompt] = None,
refine_template: Optional[RefinePrompt] = None,
simple_template: Optional[SimpleInputPrompt] = None,
response_kwargs: Optional[Dict] = None,
use_async: bool = False,
streaming: bool = False,
optimizer: Optional[BaseTokenUsageOptimizer] = None,
# class-specific args
**kwargs: Any,
) -> "BaseGPTIndexQuery":
response_synthesizer = EnhanceResponseSynthesizer.from_args(
service_context=service_context,
text_qa_template=text_qa_template,
refine_template=refine_template,
simple_template=simple_template,
response_mode=response_mode,
response_kwargs=response_kwargs,
use_async=use_async,
streaming=streaming,
optimizer=optimizer,
)
return cls(
index_struct=index_struct,
service_context=service_context,
response_synthesizer=response_synthesizer,
docstore=docstore,
node_postprocessors=node_postprocessors,
verbose=verbose,
**kwargs,
)

View File

@@ -0,0 +1,38 @@
from langchain.vectorstores import Weaviate
class WeaviateVectorStore(Weaviate):
def del_texts(self, where_filter: dict):
if not where_filter:
raise ValueError('where_filter must not be empty')
self._client.batch.delete_objects(
class_name=self._index_name,
where=where_filter,
output='minimal'
)
def del_text(self, uuid: str) -> None:
self._client.data_object.delete(
uuid,
class_name=self._index_name
)
def text_exists(self, uuid: str) -> bool:
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": uuid,
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][self._index_name]
if len(entries) == 0:
return False
return True
def delete(self):
self._client.schema.delete_class(self._index_name)

View File

@@ -1,270 +0,0 @@
import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
parse_get_response,
validate_client,
)
class WeaviateVectorStoreClient(BaseVectorStoreClient):
def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
self._client = self.init_from_config(endpoint, api_key, grpc_enabled, batch_size)
def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool, batch_size: int):
auth_config = weaviate.auth.AuthApiKey(api_key=api_key)
weaviate.connect.connection.has_grpc = grpc_enabled
client = weaviate.Client(
url=endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
index_struct = WeaviateIndexDict()
if self._client is None:
raise Exception("Vector client is not initialized.")
# {"class_prefix": "Gpt_index_xxx"}
class_prefix = config.get('class_prefix')
if not class_prefix:
raise Exception("class_prefix cannot be None.")
return GPTWeaviateEnhanceIndex(
service_context=service_context,
index_struct=index_struct,
vector_store=WeaviateWithSimilaritiesVectorStore(
weaviate_client=self._client,
class_prefix=class_prefix
)
)
def to_index_config(self, index_id: str) -> dict:
return {"class_prefix": index_id}
class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes."""
nodes = self.weaviate_query(
self._client,
self._class_prefix,
query,
)
nodes = nodes[: query.similarity_top_k]
node_idxs = [str(i) for i in range(len(nodes))]
similarities = []
for node in nodes:
similarities.append(node.extra_info['similarity'])
del node.extra_info['similarity']
return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)
def weaviate_query(
self,
client: Any,
class_prefix: str,
query_spec: VectorStoreQuery,
) -> List[Node]:
"""Convert to LlamaIndex list."""
validate_client(client)
class_name = _class_name(class_prefix)
prop_names = [p["name"] for p in NODE_SCHEMA]
vector = query_spec.query_embedding
# build query
query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
if query_spec.mode == VectorStoreQueryMode.DEFAULT:
_logger.debug("Using vector search")
if vector is not None:
query = query.with_near_vector(
{
"vector": vector,
}
)
elif query_spec.mode == VectorStoreQueryMode.HYBRID:
_logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
query = query.with_hybrid(
query=query_spec.query_str,
alpha=query_spec.alpha,
vector=vector,
)
query = query.with_limit(query_spec.similarity_top_k)
_logger.debug(f"Using limit of {query_spec.similarity_top_k}")
# execute query
query_result = query.do()
# parse results
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
results = [self._to_node(entry) for entry in entries]
return results
def _to_node(self, entry: Dict) -> Node:
"""Convert to Node."""
extra_info_str = entry["extra_info"]
if extra_info_str == "":
extra_info = None
else:
extra_info = json.loads(extra_info_str)
if 'certainty' in entry['_additional']:
if extra_info:
extra_info['similarity'] = entry['_additional']['certainty']
else:
extra_info = {'similarity': entry['_additional']['certainty']}
node_info_str = entry["node_info"]
if node_info_str == "":
node_info = None
else:
node_info = json.loads(node_info_str)
relationships_str = entry["relationships"]
relationships: Dict[DocumentRelationship, str]
if relationships_str == "":
relationships = field(default_factory=dict)
else:
relationships = {
DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
}
return Node(
text=entry["text"],
doc_id=entry["doc_id"],
embedding=entry["_additional"]["vector"],
extra_info=extra_info,
node_info=node_info,
relationships=relationships,
)
def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document.
Args:
doc_id (str): document id
"""
delete_document(self._client, doc_id, self._class_prefix)
def delete_node(self, node_id: str):
"""
Delete node from the index.
:param node_id: node id
"""
delete_node(self._client, node_id, self._class_prefix)
def exists_by_node_id(self, node_id: str) -> bool:
"""
Get node from the index by node id.
:param node_id: node id
"""
entry = get_by_node_id(self._client, node_id, self._class_prefix)
return True if entry else False
class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
pass
def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
"""Delete entry."""
validate_client(client)
# make sure that each entry
class_name = _class_name(class_prefix)
where_filter = {
"path": ["ref_doc_id"],
"operator": "Equal",
"valueString": ref_doc_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
while len(entries) > 0:
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
"""Delete entry."""
validate_client(client)
# make sure that each entry
class_name = _class_name(class_prefix)
where_filter = {
"path": ["doc_id"],
"operator": "Equal",
"valueString": node_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
for entry in entries:
client.data_object.delete(entry["_additional"]["id"], class_name)
def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
"""Delete entry."""
validate_client(client)
# make sure that each entry
class_name = _class_name(class_prefix)
where_filter = {
"path": ["doc_id"],
"operator": "Equal",
"valueString": node_id,
}
query = (
client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
)
query_result = query.do()
parsed_result = parse_get_response(query_result)
entries = parsed_result[class_name]
if len(entries) == 0:
return None
return entries[0]

View File

@@ -1,7 +0,0 @@
from core.vector_store.vector_store import VectorStore
vector_store = VectorStore()
def init_app(app):
vector_store.init_app(app)

View File

@@ -3,6 +3,7 @@ import re
import subprocess
import uuid
from datetime import datetime
from hashlib import sha256
from zoneinfo import available_timezones
import random
import string
@@ -147,3 +148,8 @@ def get_remote_ip(request):
return request.headers.getlist("X-Forwarded-For")[0]
else:
return request.remote_addr
def generate_text_hash(text: str) -> str:
hash_text = str(text) + 'None'
return sha256(hash_text.encode()).hexdigest()

View File

@@ -0,0 +1,32 @@
"""app config add speech_to_text
Revision ID: a5b56fb053ef
Revises: d3d503a3471c
Create Date: 2023-07-06 17:55:20.894149
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a5b56fb053ef'
down_revision = 'd3d503a3471c'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('speech_to_text', sa.Text(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('speech_to_text')
# ### end Alembic commands ###

View File

@@ -0,0 +1,32 @@
"""add is_deleted to conversations
Revision ID: d3d503a3471c
Revises: e32f6ccb87c6
Create Date: 2023-06-27 19:13:30.897981
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd3d503a3471c'
down_revision = 'e32f6ccb87c6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.add_column(sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('false'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.drop_column('is_deleted')
# ### end Alembic commands ###

Some files were not shown because too many files have changed in this diff Show More