mirror of
https://github.com/langgenius/dify.git
synced 2026-01-08 07:14:14 +00:00
Compare commits
61 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7dea485d57 | ||
|
|
5b9858a8a3 | ||
|
|
42a5b3ec17 | ||
|
|
2d1cb076c6 | ||
|
|
289c93d081 | ||
|
|
c0fe706597 | ||
|
|
9cba1c8bf4 | ||
|
|
cbf095465c | ||
|
|
c007dbdc13 | ||
|
|
ff493d017b | ||
|
|
7f6ad9653e | ||
|
|
2851a9f04e | ||
|
|
c536f85b2e | ||
|
|
b1352ff8b7 | ||
|
|
cc63c8499f | ||
|
|
f191b8b8d1 | ||
|
|
5003db987d | ||
|
|
07aab5e868 | ||
|
|
875dfbbf0e | ||
|
|
9e7efa45d4 | ||
|
|
8bf892b306 | ||
|
|
8480b0197b | ||
|
|
df07fb5951 | ||
|
|
4ab4bcc074 | ||
|
|
1d4f019de4 | ||
|
|
677aacc8e3 | ||
|
|
fda937175d | ||
|
|
024250803a | ||
|
|
b711ce33b7 | ||
|
|
52bec63275 | ||
|
|
657fa80f4d | ||
|
|
373e90ee6d | ||
|
|
41d4c5b424 | ||
|
|
86a9dea428 | ||
|
|
8606d80c66 | ||
|
|
5bffa1d918 | ||
|
|
c9b0fe47bf | ||
|
|
bcd744b6b7 | ||
|
|
5e511e01bf | ||
|
|
52291c645e | ||
|
|
a31466d34e | ||
|
|
d38eac959b | ||
|
|
9dbb8acd4b | ||
|
|
46154c6705 | ||
|
|
54ff03c35d | ||
|
|
18c710c906 | ||
|
|
59236b789f | ||
|
|
fd3d43cae1 | ||
|
|
8eae643911 | ||
|
|
fd9413874a | ||
|
|
227f9fb77d | ||
|
|
c40ee7e629 | ||
|
|
841e967d48 | ||
|
|
9df0dcedae | ||
|
|
724e053732 | ||
|
|
e409895c02 | ||
|
|
32d9b6181c | ||
|
|
2b018fade2 | ||
|
|
e65f9cb17a | ||
|
|
1367f34398 | ||
|
|
e47f6b879a |
@@ -1,11 +1,8 @@
|
||||
FROM mcr.microsoft.com/devcontainers/anaconda:0-3
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.10
|
||||
|
||||
COPY . .
|
||||
|
||||
# Copy environment.yml (if found) to a temp location so we update the environment. Also
|
||||
# copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists.
|
||||
COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/
|
||||
RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then umask 0002 && /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; fi \
|
||||
&& rm -rf /tmp/conda-tmp
|
||||
|
||||
# [Optional] Uncomment this section to install additional OS packages.
|
||||
# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
# && apt-get -y install --no-install-recommends <your-package-list-here>
|
||||
# && apt-get -y install --no-install-recommends <your-package-list-here>
|
||||
@@ -1,13 +1,12 @@
|
||||
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
||||
// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
|
||||
{
|
||||
"name": "Anaconda (Python 3)",
|
||||
"name": "Python 3.10",
|
||||
"build": {
|
||||
"context": "..",
|
||||
"dockerfile": "Dockerfile"
|
||||
},
|
||||
"features": {
|
||||
"ghcr.io/dhoeric/features/act:1": {},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"nodeGypDependencies": true,
|
||||
"version": "lts"
|
||||
|
||||
11
.github/ISSUE_TEMPLATE/help_wanted.yml
vendored
Normal file
11
.github/ISSUE_TEMPLATE/help_wanted.yml
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
name: "🤝 Help Wanted"
|
||||
description: "Request help from the community"
|
||||
labels:
|
||||
- help-wanted
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Provide a description of the help you need
|
||||
placeholder: Briefly describe what you need help with.
|
||||
validations:
|
||||
required: true
|
||||
2
.github/workflows/build-api-image.yml
vendored
2
.github/workflows/build-api-image.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
with:
|
||||
images: langgenius/dify-api
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=semver,pattern={{major}}.{{minor}}.{{patch}}
|
||||
|
||||
2
.github/workflows/build-web-image.yml
vendored
2
.github/workflows/build-web-image.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
with:
|
||||
images: langgenius/dify-web
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=semver,pattern={{major}}.{{minor}}.{{patch}}
|
||||
|
||||
37
.github/workflows/check_no_chinese_comments.py
vendored
37
.github/workflows/check_no_chinese_comments.py
vendored
@@ -1,37 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
from zhon.hanzi import punctuation
|
||||
|
||||
def has_chinese_characters(text):
|
||||
for char in text:
|
||||
if '\u4e00' <= char <= '\u9fff' or char in punctuation:
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_file_for_chinese_comments(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
for line_number, line in enumerate(file, start=1):
|
||||
if has_chinese_characters(line):
|
||||
print(f"Found Chinese characters in {file_path} on line {line_number}:")
|
||||
print(line.strip())
|
||||
return True
|
||||
return False
|
||||
|
||||
def main():
|
||||
has_chinese = False
|
||||
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
|
||||
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
|
||||
'prompts.py']
|
||||
|
||||
for root, _, files in os.walk("."):
|
||||
for file in files:
|
||||
if file.endswith(".py") and file not in excluded_files:
|
||||
file_path = os.path.join(root, file)
|
||||
if check_file_for_chinese_comments(file_path):
|
||||
has_chinese = True
|
||||
|
||||
if has_chinese:
|
||||
raise Exception("Found Chinese characters in Python files. Please remove them.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
31
.github/workflows/check_no_chinese_comments.yml
vendored
31
.github/workflows/check_no_chinese_comments.yml
vendored
@@ -1,31 +0,0 @@
|
||||
name: Check for Chinese comments
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
check-chinese-comments:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.9
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install zhon
|
||||
|
||||
- name: Run script to check for Chinese comments
|
||||
run: |
|
||||
python .github/workflows/check_no_chinese_comments.py
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -144,6 +144,7 @@ docker/volumes/app/storage/*
|
||||
docker/volumes/db/data/*
|
||||
docker/volumes/redis/data/*
|
||||
docker/volumes/weaviate/*
|
||||
docker/volumes/qdrant/*
|
||||
|
||||
sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
|
||||
@@ -50,25 +50,7 @@ S3_REGION=your-region
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Cookie configuration
|
||||
COOKIE_HTTPONLY=true
|
||||
COOKIE_SAMESITE=None
|
||||
COOKIE_SECURE=true
|
||||
|
||||
# Session configuration
|
||||
SESSION_PERMANENT=true
|
||||
SESSION_USE_SIGNER=true
|
||||
|
||||
## support redis, sqlalchemy
|
||||
SESSION_TYPE=redis
|
||||
|
||||
# session redis configuration
|
||||
SESSION_REDIS_HOST=localhost
|
||||
SESSION_REDIS_PORT=6379
|
||||
SESSION_REDIS_PASSWORD=difyai123456
|
||||
SESSION_REDIS_DB=2
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# Weaviate configuration
|
||||
@@ -77,9 +59,16 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
|
||||
# Qdrant configuration, use `path:` prefix for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
|
||||
QDRANT_URL=path:storage/qdrant
|
||||
QDRANT_API_KEY=your-qdrant-api-key
|
||||
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
|
||||
QDRANT_URL=http://localhost:6333
|
||||
QDRANT_API_KEY=difyai123456
|
||||
|
||||
# Milvus configuration
|
||||
MILVUS_HOST=127.0.0.1
|
||||
MILVUS_PORT=19530
|
||||
MILVUS_USER=root
|
||||
MILVUS_PASSWORD=Milvus
|
||||
MILVUS_SECURE=false
|
||||
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE=
|
||||
|
||||
94
api/app.py
94
api/app.py
@@ -1,23 +1,24 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||
from gevent import monkey
|
||||
monkey.patch_all()
|
||||
if os.environ.get("VECTOR_STORE") == 'milvus':
|
||||
import grpc.experimental.gevent
|
||||
grpc.experimental.gevent.init_gevent()
|
||||
|
||||
import logging
|
||||
import json
|
||||
import threading
|
||||
|
||||
from flask import Flask, request, Response, session
|
||||
import flask_login
|
||||
from flask import Flask, request, Response
|
||||
from flask_cors import CORS
|
||||
|
||||
from core.model_providers.providers import hosted
|
||||
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
ext_database, ext_storage, ext_mail, ext_stripe
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
@@ -27,12 +28,10 @@ from models import model, account, dataset, web, task, source, tool
|
||||
from events import event_handlers
|
||||
# DO NOT REMOVE ABOVE
|
||||
|
||||
import core
|
||||
from config import Config, CloudEditionConfig
|
||||
from commands import register_commands
|
||||
from models.account import TenantAccountJoin, AccountStatus
|
||||
from models.model import Account, EndUser, App
|
||||
from services.account_service import TenantService
|
||||
from services.account_service import AccountService
|
||||
from libs.passport import PassportService
|
||||
|
||||
import warnings
|
||||
warnings.simplefilter("ignore", ResourceWarning)
|
||||
@@ -85,81 +84,33 @@ def initialize_extensions(app):
|
||||
ext_redis.init_app(app)
|
||||
ext_storage.init_app(app)
|
||||
ext_celery.init_app(app)
|
||||
ext_session.init_app(app)
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
ext_stripe.init_app(app)
|
||||
|
||||
|
||||
def _create_tenant_for_account(account):
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
|
||||
TenantService.create_tenant_member(tenant, account, role='owner')
|
||||
account.current_tenant = tenant
|
||||
|
||||
return tenant
|
||||
|
||||
|
||||
# Flask-Login configuration
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id):
|
||||
"""Load user based on the user_id."""
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint == 'console':
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
if '.' in user_id:
|
||||
tenant_id, account_id = user_id.split('.')
|
||||
else:
|
||||
account_id = user_id
|
||||
auth_header = request.headers.get('Authorization', '')
|
||||
if ' ' not in auth_header:
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != 'bearer':
|
||||
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
|
||||
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get('user_id')
|
||||
|
||||
account = db.session.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
if account:
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
raise Forbidden('Account is banned or closed.')
|
||||
|
||||
workspace_id = session.get('workspace_id')
|
||||
if workspace_id:
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.account_id == account.id,
|
||||
TenantAccountJoin.tenant_id == workspace_id
|
||||
).first()
|
||||
|
||||
if not tenant_account_join:
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.account_id == account.id).first()
|
||||
|
||||
if tenant_account_join:
|
||||
account.current_tenant_id = tenant_account_join.tenant_id
|
||||
else:
|
||||
_create_tenant_for_account(account)
|
||||
session['workspace_id'] = account.current_tenant_id
|
||||
else:
|
||||
account.current_tenant_id = workspace_id
|
||||
else:
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.account_id == account.id).first()
|
||||
if tenant_account_join:
|
||||
account.current_tenant_id = tenant_account_join.tenant_id
|
||||
else:
|
||||
_create_tenant_for_account(account)
|
||||
session['workspace_id'] = account.current_tenant_id
|
||||
|
||||
current_time = datetime.utcnow()
|
||||
|
||||
# update last_active_at when last_active_at is more than 10 minutes ago
|
||||
if current_time - account.last_active_at > timedelta(minutes=10):
|
||||
account.last_active_at = current_time
|
||||
db.session.commit()
|
||||
|
||||
# Log in the user with the updated user_id
|
||||
flask_login.login_user(account, remember=True)
|
||||
|
||||
return account
|
||||
return AccountService.load_user(user_id)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@login_manager.unauthorized_handler
|
||||
def unauthorized_handler():
|
||||
"""Handle unauthorized requests."""
|
||||
@@ -216,6 +167,7 @@ if app.config['TESTING']:
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
"""Add Version headers to the response."""
|
||||
response.set_cookie('remember_token', '', expires=0)
|
||||
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
|
||||
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
|
||||
return response
|
||||
|
||||
222
api/commands.py
222
api/commands.py
@@ -3,12 +3,13 @@ import json
|
||||
import math
|
||||
import random
|
||||
import string
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import click
|
||||
from tqdm import tqdm
|
||||
from flask import current_app
|
||||
from flask import current_app, Flask
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@@ -456,92 +457,92 @@ def update_qdrant_indexes():
|
||||
@click.command('normalization-collections', help='restore all collections in one')
|
||||
def normalization_collections():
|
||||
click.echo(click.style('Start normalization collections.', fg='green'))
|
||||
normalization_count = 0
|
||||
|
||||
normalization_count = []
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
datasets_result = datasets.items
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
if not dataset.collection_binding_id:
|
||||
try:
|
||||
click.echo('restore dataset index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except Exception:
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
|
||||
DatasetCollectionBinding.model_name == embedding_model.name). \
|
||||
order_by(DatasetCollectionBinding.created_at). \
|
||||
first()
|
||||
for i in range(0, len(datasets_result), 5):
|
||||
threads = []
|
||||
sub_datasets = datasets_result[i:i + 5]
|
||||
for dataset in sub_datasets:
|
||||
document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'normalization_count': normalization_count
|
||||
})
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if not dataset_collection_binding:
|
||||
dataset_collection_binding = DatasetCollectionBinding(
|
||||
provider_name=embedding_model.model_provider.provider_name,
|
||||
model_name=embedding_model.name,
|
||||
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
|
||||
)
|
||||
db.session.add(dataset_collection_binding)
|
||||
db.session.commit()
|
||||
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
index.restore_dataset_in_one(dataset, dataset_collection_binding)
|
||||
else:
|
||||
click.echo('passed.')
|
||||
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
click.echo('restore dataset index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except Exception:
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
|
||||
DatasetCollectionBinding.model_name == embedding_model.name). \
|
||||
order_by(DatasetCollectionBinding.created_at). \
|
||||
first()
|
||||
|
||||
original_index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if original_index:
|
||||
original_index.delete_original_collection(dataset, dataset_collection_binding)
|
||||
normalization_count += 1
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
if not dataset_collection_binding:
|
||||
dataset_collection_binding = DatasetCollectionBinding(
|
||||
provider_name=embedding_model.model_provider.provider_name,
|
||||
model_name=embedding_model.name,
|
||||
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
|
||||
)
|
||||
db.session.add(dataset_collection_binding)
|
||||
db.session.commit()
|
||||
|
||||
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
# index.delete_by_group_id(dataset.id)
|
||||
index.restore_dataset_in_one(dataset, dataset_collection_binding)
|
||||
else:
|
||||
click.echo('passed.')
|
||||
normalization_count.append(1)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
|
||||
|
||||
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
|
||||
@@ -646,6 +647,76 @@ def update_app_model_configs(batch_size):
|
||||
|
||||
pbar.update(len(data_batch))
|
||||
|
||||
@click.command('migrate_default_input_to_dataset_query_variable')
|
||||
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
|
||||
def migrate_default_input_to_dataset_query_variable(batch_size):
|
||||
|
||||
click.secho("Starting...", fg='green')
|
||||
|
||||
total_records = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.filter(AppModelConfig.dataset_query_variable == None) \
|
||||
.count()
|
||||
|
||||
if total_records == 0:
|
||||
click.secho("No data to migrate.", fg='green')
|
||||
return
|
||||
|
||||
num_batches = (total_records + batch_size - 1) // batch_size
|
||||
|
||||
with tqdm(total=total_records, desc="Migrating Data") as pbar:
|
||||
for i in range(num_batches):
|
||||
offset = i * batch_size
|
||||
limit = min(batch_size, total_records - offset)
|
||||
|
||||
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
|
||||
|
||||
data_batch = db.session.query(AppModelConfig) \
|
||||
.join(App, App.app_model_config_id == AppModelConfig.id) \
|
||||
.filter(App.mode == 'completion') \
|
||||
.filter(AppModelConfig.dataset_query_variable == None) \
|
||||
.order_by(App.created_at) \
|
||||
.offset(offset).limit(limit).all()
|
||||
|
||||
if not data_batch:
|
||||
click.secho("No more data to migrate.", fg='green')
|
||||
break
|
||||
|
||||
try:
|
||||
click.secho(f"Migrating {len(data_batch)} records...", fg='green')
|
||||
for data in data_batch:
|
||||
config = AppModelConfig.to_dict(data)
|
||||
|
||||
tools = config["agent_mode"]["tools"]
|
||||
dataset_exists = "dataset" in str(tools)
|
||||
if not dataset_exists:
|
||||
continue
|
||||
|
||||
user_input_form = config.get("user_input_form", [])
|
||||
for form in user_input_form:
|
||||
paragraph = form.get('paragraph')
|
||||
if paragraph \
|
||||
and paragraph.get('variable') == 'query':
|
||||
data.dataset_query_variable = 'query'
|
||||
break
|
||||
|
||||
if paragraph \
|
||||
and paragraph.get('variable') == 'default_input':
|
||||
data.dataset_query_variable = 'default_input'
|
||||
break
|
||||
|
||||
db.session.commit()
|
||||
|
||||
except Exception as e:
|
||||
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
|
||||
fg='red')
|
||||
continue
|
||||
|
||||
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
|
||||
|
||||
pbar.update(len(data_batch))
|
||||
|
||||
|
||||
def register_commands(app):
|
||||
app.cli.add_command(reset_password)
|
||||
@@ -659,3 +730,4 @@ def register_commands(app):
|
||||
app.cli.add_command(update_qdrant_indexes)
|
||||
app.cli.add_command(update_app_model_configs)
|
||||
app.cli.add_command(normalization_collections)
|
||||
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
|
||||
|
||||
@@ -10,9 +10,6 @@ from extensions.ext_redis import redis_client
|
||||
dotenv.load_dotenv()
|
||||
|
||||
DEFAULTS = {
|
||||
'COOKIE_HTTPONLY': 'True',
|
||||
'COOKIE_SECURE': 'True',
|
||||
'COOKIE_SAMESITE': 'None',
|
||||
'DB_USERNAME': 'postgres',
|
||||
'DB_PASSWORD': '',
|
||||
'DB_HOST': 'localhost',
|
||||
@@ -22,10 +19,6 @@ DEFAULTS = {
|
||||
'REDIS_PORT': '6379',
|
||||
'REDIS_DB': '0',
|
||||
'REDIS_USE_SSL': 'False',
|
||||
'SESSION_REDIS_HOST': 'localhost',
|
||||
'SESSION_REDIS_PORT': '6379',
|
||||
'SESSION_REDIS_DB': '2',
|
||||
'SESSION_REDIS_USE_SSL': 'False',
|
||||
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
|
||||
'OAUTH_REDIRECT_INDEX_PATH': '/',
|
||||
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
|
||||
@@ -36,9 +29,6 @@ DEFAULTS = {
|
||||
'STORAGE_TYPE': 'local',
|
||||
'STORAGE_LOCAL_PATH': 'storage',
|
||||
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
|
||||
'SESSION_TYPE': 'sqlalchemy',
|
||||
'SESSION_PERMANENT': 'True',
|
||||
'SESSION_USE_SIGNER': 'True',
|
||||
'DEPLOY_ENV': 'PRODUCTION',
|
||||
'SQLALCHEMY_POOL_SIZE': 30,
|
||||
'SQLALCHEMY_POOL_RECYCLE': 3600,
|
||||
@@ -102,7 +92,7 @@ class Config:
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.23"
|
||||
self.CURRENT_VERSION = "0.3.27"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
@@ -115,20 +105,6 @@ class Config:
|
||||
# Alternatively you can set it with `SECRET_KEY` environment variable.
|
||||
self.SECRET_KEY = get_env('SECRET_KEY')
|
||||
|
||||
# cookie settings
|
||||
self.REMEMBER_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
|
||||
self.SESSION_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
|
||||
self.REMEMBER_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
|
||||
self.SESSION_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
|
||||
self.REMEMBER_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
|
||||
self.SESSION_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
|
||||
self.PERMANENT_SESSION_LIFETIME = timedelta(days=7)
|
||||
|
||||
# session settings, only support sqlalchemy, redis
|
||||
self.SESSION_TYPE = get_env('SESSION_TYPE')
|
||||
self.SESSION_PERMANENT = get_bool_env('SESSION_PERMANENT')
|
||||
self.SESSION_USE_SIGNER = get_bool_env('SESSION_USE_SIGNER')
|
||||
|
||||
# redis settings
|
||||
self.REDIS_HOST = get_env('REDIS_HOST')
|
||||
self.REDIS_PORT = get_env('REDIS_PORT')
|
||||
@@ -137,14 +113,6 @@ class Config:
|
||||
self.REDIS_DB = get_env('REDIS_DB')
|
||||
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
|
||||
|
||||
# session redis settings
|
||||
self.SESSION_REDIS_HOST = get_env('SESSION_REDIS_HOST')
|
||||
self.SESSION_REDIS_PORT = get_env('SESSION_REDIS_PORT')
|
||||
self.SESSION_REDIS_USERNAME = get_env('SESSION_REDIS_USERNAME')
|
||||
self.SESSION_REDIS_PASSWORD = get_env('SESSION_REDIS_PASSWORD')
|
||||
self.SESSION_REDIS_DB = get_env('SESSION_REDIS_DB')
|
||||
self.SESSION_REDIS_USE_SSL = get_bool_env('SESSION_REDIS_USE_SSL')
|
||||
|
||||
# storage settings
|
||||
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
|
||||
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
|
||||
@@ -167,6 +135,14 @@ class Config:
|
||||
self.QDRANT_URL = get_env('QDRANT_URL')
|
||||
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
|
||||
|
||||
# milvus setting
|
||||
self.MILVUS_HOST = get_env('MILVUS_HOST')
|
||||
self.MILVUS_PORT = get_env('MILVUS_PORT')
|
||||
self.MILVUS_USER = get_env('MILVUS_USER')
|
||||
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
|
||||
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
|
||||
|
||||
|
||||
# cors settings
|
||||
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
||||
|
||||
@@ -31,6 +31,7 @@ model_templates = {
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
@@ -81,6 +82,7 @@ model_templates = {
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
@@ -137,10 +139,11 @@ demo_model_templates = {
|
||||
},
|
||||
opening_statement='',
|
||||
suggested_questions=None,
|
||||
pre_prompt="Please translate the following text into {{target_language}}:\n",
|
||||
pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
@@ -169,6 +172,13 @@ demo_model_templates = {
|
||||
'Italian',
|
||||
]
|
||||
}
|
||||
},{
|
||||
"paragraph": {
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
])
|
||||
)
|
||||
@@ -200,6 +210,7 @@ demo_model_templates = {
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.8,
|
||||
@@ -255,10 +266,11 @@ demo_model_templates = {
|
||||
},
|
||||
opening_statement='',
|
||||
suggested_questions=None,
|
||||
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
|
||||
pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
@@ -287,6 +299,13 @@ demo_model_templates = {
|
||||
"意大利语",
|
||||
]
|
||||
}
|
||||
},{
|
||||
"paragraph": {
|
||||
"label": "文本内容",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
])
|
||||
)
|
||||
@@ -318,6 +337,7 @@ demo_model_templates = {
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.8,
|
||||
|
||||
@@ -9,7 +9,7 @@ api = ExternalApi(bp)
|
||||
from . import setup, version, apikey, admin
|
||||
|
||||
# Import app controllers
|
||||
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
|
||||
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import login, oauth, data_source_oauth, activate
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
import flask_restful
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@@ -81,6 +81,7 @@ class BaseApiKeyListResource(Resource):
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
api_token.token = key
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
|
||||
26
api/controllers/console/app/advanced_prompt_template.py
Normal file
26
api/controllers/console/app/advanced_prompt_template.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
class AdvancedPromptTemplateList(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('app_mode', type=str, required=True, location='args')
|
||||
parser.add_argument('model_mode', type=str, required=True, location='args')
|
||||
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
|
||||
parser.add_argument('model_name', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
service = AdvancedPromptTemplateService()
|
||||
return service.get_prompt(args)
|
||||
|
||||
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
|
||||
@@ -3,10 +3,9 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import flask
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, inputs
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from constants.model_template import model_templates, demo_model_templates
|
||||
@@ -17,42 +16,13 @@ from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from libs.helper import TimestampField
|
||||
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
|
||||
app_detail_fields_with_site
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig, Site
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
model_config_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
|
||||
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
||||
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
||||
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
|
||||
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
||||
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'user_input_form': fields.Raw(attribute='user_input_form_list'),
|
||||
'pre_prompt': fields.String,
|
||||
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
|
||||
}
|
||||
|
||||
app_detail_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'api_rpm': fields.Integer,
|
||||
'api_rph': fields.Integer,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
def _get_app(app_id, tenant_id):
|
||||
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
|
||||
@@ -62,35 +32,6 @@ def _get_app(app_id, tenant_id):
|
||||
|
||||
|
||||
class AppListApi(Resource):
|
||||
prompt_config_fields = {
|
||||
'prompt_template': fields.String,
|
||||
}
|
||||
|
||||
model_config_partial_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
app_partial_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
app_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -162,7 +103,8 @@ class AppListApi(Resource):
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=model_config_dict
|
||||
config=model_config_dict,
|
||||
mode=args['mode']
|
||||
)
|
||||
|
||||
app = App(
|
||||
@@ -236,18 +178,6 @@ class AppListApi(Resource):
|
||||
|
||||
|
||||
class AppTemplateApi(Resource):
|
||||
template_fields = {
|
||||
'name': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'mode': fields.String,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
}
|
||||
|
||||
template_list_fields = {
|
||||
'data': fields.List(fields.Nested(template_fields)),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -266,38 +196,6 @@ class AppTemplateApi(Resource):
|
||||
|
||||
|
||||
class AppApi(Resource):
|
||||
site_fields = {
|
||||
'access_token': fields.String(attribute='code'),
|
||||
'code': fields.String,
|
||||
'title': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'default_language': fields.String,
|
||||
'customize_domain': fields.String,
|
||||
'copyright': fields.String,
|
||||
'privacy_policy': fields.String,
|
||||
'customize_token_strategy': fields.String,
|
||||
'prompt_public': fields.Boolean,
|
||||
'app_base_url': fields.String,
|
||||
}
|
||||
|
||||
app_detail_fields_with_site = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'api_rpm': fields.Integer,
|
||||
'api_rph': fields.Integer,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
||||
'site': fields.Nested(site_fields),
|
||||
'api_base_url': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from core.login.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from libs.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Generator, Union
|
||||
|
||||
import flask_login
|
||||
from flask import Response, stream_with_context
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
||||
@@ -2,8 +2,8 @@ from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy import or_, func
|
||||
from sqlalchemy.orm import joinedload
|
||||
@@ -13,107 +13,14 @@ from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.helper import TimestampField, datetime_string, uuid_value
|
||||
from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
|
||||
conversation_message_detail_fields, conversation_with_summary_pagination_fields
|
||||
from libs.helper import datetime_string
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, MessageAnnotation, Conversation
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'email': fields.String
|
||||
}
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String,
|
||||
'content': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account': fields.Nested(account_fields, allow_null=True),
|
||||
}
|
||||
|
||||
annotation_fields = {
|
||||
'content': fields.String,
|
||||
'account': fields.Nested(account_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_detail_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'message': fields.Raw,
|
||||
'message_tokens': fields.Integer,
|
||||
'answer': fields.String,
|
||||
'answer_tokens': fields.Integer,
|
||||
'provider_response_latency': fields.Float,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'feedbacks': fields.List(fields.Nested(feedback_fields)),
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
feedback_stat_fields = {
|
||||
'like': fields.Integer,
|
||||
'dislike': fields.Integer
|
||||
}
|
||||
|
||||
model_config_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'model': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'pre_prompt': fields.String,
|
||||
'agent_mode': fields.Raw,
|
||||
}
|
||||
|
||||
|
||||
class CompletionConversationApi(Resource):
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]['text'] if value else ''
|
||||
|
||||
simple_configs_fields = {
|
||||
'prompt_template': fields.String,
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
simple_message_detail_fields = {
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'message': MessageTextField,
|
||||
'answer': fields.String,
|
||||
}
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String(),
|
||||
'from_account_id': fields.String,
|
||||
'read_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'model_config': fields.Nested(simple_model_config_fields),
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
|
||||
}
|
||||
|
||||
conversation_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -191,21 +98,11 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
|
||||
class CompletionConversationDetailApi(Resource):
|
||||
conversation_detail_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
'message': fields.Nested(message_detail_fields, attribute='first_message'),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(conversation_detail_fields)
|
||||
@marshal_with(conversation_message_detail_fields)
|
||||
def get(self, app_id, conversation_id):
|
||||
app_id = str(app_id)
|
||||
conversation_id = str(conversation_id)
|
||||
@@ -234,44 +131,11 @@ class CompletionConversationDetailApi(Resource):
|
||||
|
||||
|
||||
class ChatConversationApi(Resource):
|
||||
simple_configs_fields = {
|
||||
'prompt_template': fields.String,
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'summary': fields.String(attribute='summary_or_query'),
|
||||
'read_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
'annotated': fields.Boolean,
|
||||
'model_config': fields.Nested(simple_model_config_fields),
|
||||
'message_count': fields.Integer,
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
||||
}
|
||||
|
||||
conversation_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(conversation_pagination_fields)
|
||||
@marshal_with(conversation_with_summary_pagination_fields)
|
||||
def get(self, app_id):
|
||||
app_id = str(app_id)
|
||||
|
||||
@@ -356,19 +220,6 @@ class ChatConversationApi(Resource):
|
||||
|
||||
|
||||
class ChatConversationDetailApi(Resource):
|
||||
conversation_detail_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'annotated': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
'message_count': fields.Integer,
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
@@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
|
||||
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
||||
|
||||
|
||||
class IntroductionGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('prompt_template', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
|
||||
try:
|
||||
answer = LLMGenerator.generate_introduction(
|
||||
account.current_tenant_id,
|
||||
args['prompt_template']
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
|
||||
return {'introduction': answer}
|
||||
|
||||
|
||||
class RuleGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -72,5 +43,4 @@ class RuleGenerateApi(Resource):
|
||||
return rules
|
||||
|
||||
|
||||
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
|
||||
api.add_resource(RuleGenerateApi, '/rule-generate')
|
||||
|
||||
@@ -16,8 +16,9 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.login.login import login_required
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from libs.login import login_required
|
||||
from fields.conversation_fields import message_detail_fields
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
|
||||
@@ -27,44 +28,6 @@ from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.message_service import MessageService
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'email': fields.String
|
||||
}
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String,
|
||||
'content': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account': fields.Nested(account_fields, allow_null=True),
|
||||
}
|
||||
|
||||
annotation_fields = {
|
||||
'content': fields.String,
|
||||
'account': fields.Nested(account_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_detail_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'message': fields.Raw,
|
||||
'message_tokens': fields.Integer,
|
||||
'answer': fields.String,
|
||||
'answer_tokens': fields.Integer,
|
||||
'provider_response_latency': fields.Float,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'feedbacks': fields.List(fields.Nested(feedback_fields)),
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
|
||||
class ChatMessageListApi(Resource):
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
@@ -366,7 +329,7 @@ class MessageApi(Resource):
|
||||
message_id = str(message_id)
|
||||
|
||||
# get app info
|
||||
app_model = _get_app(app_id, 'chat')
|
||||
app_model = _get_app(app_id)
|
||||
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
@@ -9,7 +8,7 @@ from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppModelConfig
|
||||
@@ -31,7 +30,8 @@ class ModelConfigResource(Resource):
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=request.json
|
||||
config=request.json,
|
||||
mode=app_model.mode
|
||||
)
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
|
||||
@@ -1,33 +1,18 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.helper import supported_language
|
||||
from extensions.ext_database import db
|
||||
from models.model import Site
|
||||
|
||||
app_site_fields = {
|
||||
'app_id': fields.String,
|
||||
'access_token': fields.String(attribute='code'),
|
||||
'code': fields.String,
|
||||
'title': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'default_language': fields.String,
|
||||
'customize_domain': fields.String,
|
||||
'copyright': fields.String,
|
||||
'privacy_policy': fields.String,
|
||||
'customize_token_strategy': fields.String,
|
||||
'prompt_public': fields.Boolean
|
||||
}
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
@@ -5,7 +5,7 @@ from datetime import datetime
|
||||
import pytz
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import flask_login
|
||||
import requests
|
||||
from flask import request, redirect, current_app, session
|
||||
from flask import request, redirect, current_app
|
||||
from flask_login import current_user
|
||||
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
from controllers.console import api
|
||||
from ..setup import setup_required
|
||||
@@ -45,15 +42,34 @@ class OAuthDataSource(Resource):
|
||||
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
|
||||
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
|
||||
oauth_provider.save_internal_access_token(internal_secret)
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
return { 'data': '' }
|
||||
else:
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
return redirect(auth_url)
|
||||
return { 'data': auth_url }, 200
|
||||
|
||||
|
||||
|
||||
|
||||
class OAuthDataSourceCallback(Resource):
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {'error': 'Invalid provider'}, 400
|
||||
if 'code' in request.args:
|
||||
code = request.args.get('code')
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&code={code}')
|
||||
elif 'error' in request.args:
|
||||
error = request.args.get('error')
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error={error}')
|
||||
else:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error=Access denied')
|
||||
|
||||
|
||||
class OAuthDataSourceBinding(Resource):
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
@@ -69,12 +85,7 @@ class OAuthDataSourceCallback(Resource):
|
||||
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
|
||||
return {'error': 'OAuth data source process failed'}, 400
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
|
||||
elif 'error' in request.args:
|
||||
error = request.args.get('error')
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source={error}')
|
||||
else:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=access_denied')
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class OAuthDataSourceSync(Resource):
|
||||
@@ -101,4 +112,5 @@ class OAuthDataSourceSync(Resource):
|
||||
|
||||
api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
|
||||
api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
|
||||
|
||||
@@ -6,7 +6,6 @@ from flask_restful import Resource, reqparse
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.setup import setup_required
|
||||
from libs.helper import email
|
||||
from libs.password import valid_password
|
||||
@@ -37,12 +36,12 @@ class LoginApi(Resource):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
flask_login.login_user(account, remember=args['remember_me'])
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
# todo: return the user info
|
||||
token = AccountService.get_account_jwt_token(account)
|
||||
|
||||
return {'result': 'success'}
|
||||
return {'result': 'success', 'data': token}
|
||||
|
||||
|
||||
class LogoutApi(Resource):
|
||||
|
||||
@@ -2,9 +2,8 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import flask_login
|
||||
import requests
|
||||
from flask import request, redirect, current_app, session
|
||||
from flask import request, redirect, current_app
|
||||
from flask_restful import Resource
|
||||
|
||||
from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth
|
||||
@@ -75,12 +74,11 @@ class OAuthCallback(Resource):
|
||||
account.initialized_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
# login user
|
||||
session.clear()
|
||||
flask_login.login_user(account, remember=True)
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
|
||||
token = AccountService.get_account_jwt_token(account)
|
||||
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||
|
||||
@@ -2,10 +2,10 @@ import datetime
|
||||
import json
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
@@ -14,7 +14,7 @@ from controllers.console.wraps import account_initialization_required
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
|
||||
from models.dataset import Document
|
||||
from models.source import DataSourceBinding
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
@@ -24,37 +24,6 @@ cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
|
||||
class DataSourceApi(Resource):
|
||||
integrate_icon_fields = {
|
||||
'type': fields.String,
|
||||
'url': fields.String,
|
||||
'emoji': fields.String
|
||||
}
|
||||
integrate_page_fields = {
|
||||
'page_name': fields.String,
|
||||
'page_id': fields.String,
|
||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||
'parent_id': fields.String,
|
||||
'type': fields.String
|
||||
}
|
||||
integrate_workspace_fields = {
|
||||
'workspace_name': fields.String,
|
||||
'workspace_id': fields.String,
|
||||
'workspace_icon': fields.String,
|
||||
'pages': fields.List(fields.Nested(integrate_page_fields)),
|
||||
'total': fields.Integer
|
||||
}
|
||||
integrate_fields = {
|
||||
'id': fields.String,
|
||||
'provider': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'is_bound': fields.Boolean,
|
||||
'disabled': fields.Boolean,
|
||||
'link': fields.String,
|
||||
'source_info': fields.Nested(integrate_workspace_fields)
|
||||
}
|
||||
integrate_list_fields = {
|
||||
'data': fields.List(fields.Nested(integrate_fields)),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -131,28 +100,6 @@ class DataSourceApi(Resource):
|
||||
|
||||
|
||||
class DataSourceNotionListApi(Resource):
|
||||
integrate_icon_fields = {
|
||||
'type': fields.String,
|
||||
'url': fields.String,
|
||||
'emoji': fields.String
|
||||
}
|
||||
integrate_page_fields = {
|
||||
'page_name': fields.String,
|
||||
'page_id': fields.String,
|
||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||
'is_bound': fields.Boolean,
|
||||
'parent_id': fields.String,
|
||||
'type': fields.String
|
||||
}
|
||||
integrate_workspace_fields = {
|
||||
'workspace_name': fields.String,
|
||||
'workspace_id': fields.String,
|
||||
'workspace_icon': fields.String,
|
||||
'pages': fields.List(fields.Nested(integrate_page_fields))
|
||||
}
|
||||
integrate_notion_info_list_fields = {
|
||||
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import request
|
||||
import flask_restful
|
||||
from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
||||
|
||||
from controllers.console.apikey import api_key_list, api_key_fields
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
import services
|
||||
from controllers.console import api
|
||||
@@ -12,45 +15,16 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from libs.helper import TimestampField
|
||||
from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment, Document
|
||||
from models.model import UploadFile
|
||||
from models.model import UploadFile, ApiToken
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
dataset_detail_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'description': fields.String,
|
||||
'provider': fields.String,
|
||||
'permission': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'indexing_technique': fields.String,
|
||||
'app_count': fields.Integer,
|
||||
'document_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'updated_by': fields.String,
|
||||
'updated_at': TimestampField,
|
||||
'embedding_model': fields.String,
|
||||
'embedding_model_provider': fields.String,
|
||||
'embedding_available': fields.Boolean
|
||||
}
|
||||
|
||||
dataset_query_detail_fields = {
|
||||
"id": fields.String,
|
||||
"content": fields.String,
|
||||
"source": fields.String,
|
||||
"source_app_id": fields.String,
|
||||
"created_by_role": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField
|
||||
}
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
@@ -82,7 +56,8 @@ class DatasetListApi(Resource):
|
||||
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||
ModelType.EMBEDDINGS.value)
|
||||
# if len(valid_model_list) == 0:
|
||||
# raise ProviderNotInitializeError(
|
||||
# f"No Embedding Model available. Please configure a valid provider "
|
||||
@@ -157,7 +132,8 @@ class DatasetApi(Resource):
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
# get valid model list
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||
ModelType.EMBEDDINGS.value)
|
||||
model_names = []
|
||||
for valid_model in valid_model_list:
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
@@ -271,7 +247,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
@@ -320,18 +297,6 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
|
||||
|
||||
class DatasetRelatedAppListApi(Resource):
|
||||
app_detail_kernel_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
}
|
||||
|
||||
related_app_list = {
|
||||
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
|
||||
'total': fields.Integer,
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -363,24 +328,6 @@ class DatasetRelatedAppListApi(Resource):
|
||||
|
||||
|
||||
class DatasetIndexingStatusApi(Resource):
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
'indexing_status': fields.String,
|
||||
'processing_started_at': TimestampField,
|
||||
'parsing_completed_at': TimestampField,
|
||||
'cleaning_completed_at': TimestampField,
|
||||
'splitting_completed_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'paused_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer,
|
||||
}
|
||||
|
||||
document_status_fields_list = {
|
||||
'data': fields.List(fields.Nested(document_status_fields))
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -400,16 +347,101 @@ class DatasetIndexingStatusApi(Resource):
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
documents_status.append(marshal(document, self.document_status_fields))
|
||||
documents_status.append(marshal(document, document_status_fields))
|
||||
data = {
|
||||
'data': documents_status
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class DatasetApiKeyApi(Resource):
|
||||
max_keys = 10
|
||||
token_prefix = 'dataset-'
|
||||
resource_type = 'dataset'
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_list)
|
||||
def get(self):
|
||||
keys = db.session.query(ApiToken). \
|
||||
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
|
||||
all()
|
||||
return {"items": keys}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_fields)
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = db.session.query(ApiToken). \
|
||||
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
|
||||
count()
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
flask_restful.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code='max_keys_exceeded'
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
api_token = ApiToken()
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
api_token.token = key
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return api_token, 200
|
||||
|
||||
|
||||
class DatasetApiDeleteApi(Resource):
|
||||
resource_type = 'dataset'
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, api_key_id):
|
||||
api_key_id = str(api_key_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
|
||||
key = db.session.query(ApiToken). \
|
||||
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id). \
|
||||
first()
|
||||
|
||||
if key is None:
|
||||
flask_restful.abort(404, message='API key not found')
|
||||
|
||||
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class DatasetApiBaseUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
return {
|
||||
'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
|
||||
else request.host_url.rstrip('/')) + '/v1'
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, '/datasets')
|
||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
||||
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
||||
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
|
||||
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
|
||||
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
|
||||
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
|
||||
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
|
||||
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import desc, asc
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
@@ -23,7 +22,8 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
|
||||
LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import TimestampField
|
||||
from fields.document_fields import document_with_segments_fields, document_fields, \
|
||||
dataset_and_document_fields, document_status_fields
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetProcessRule, Dataset
|
||||
from models.dataset import Document, DocumentSegment
|
||||
@@ -32,64 +32,6 @@ from services.dataset_service import DocumentService, DatasetService
|
||||
from tasks.add_document_to_index_task import add_document_to_index_task
|
||||
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
||||
|
||||
dataset_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'description': fields.String,
|
||||
'permission': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'indexing_technique': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
}
|
||||
|
||||
document_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'data_source_type': fields.String,
|
||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||
'dataset_process_rule_id': fields.String,
|
||||
'name': fields.String,
|
||||
'created_from': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'tokens': fields.Integer,
|
||||
'indexing_status': fields.String,
|
||||
'error': fields.String,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'archived': fields.Boolean,
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'doc_form': fields.String,
|
||||
}
|
||||
|
||||
document_with_segments_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'data_source_type': fields.String,
|
||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||
'dataset_process_rule_id': fields.String,
|
||||
'name': fields.String,
|
||||
'created_from': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'tokens': fields.Integer,
|
||||
'indexing_status': fields.String,
|
||||
'error': fields.String,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'archived': fields.Boolean,
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer
|
||||
}
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
@@ -303,11 +245,6 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
|
||||
class DatasetInitApi(Resource):
|
||||
dataset_and_document_fields = {
|
||||
'dataset': fields.Nested(dataset_fields),
|
||||
'documents': fields.List(fields.Nested(document_fields)),
|
||||
'batch': fields.String
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -504,24 +441,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
|
||||
|
||||
class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
'indexing_status': fields.String,
|
||||
'processing_started_at': TimestampField,
|
||||
'parsing_completed_at': TimestampField,
|
||||
'cleaning_completed_at': TimestampField,
|
||||
'splitting_completed_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'paused_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer,
|
||||
}
|
||||
|
||||
document_status_fields_list = {
|
||||
'data': fields.List(fields.Nested(document_status_fields))
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -541,7 +460,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = 'paused'
|
||||
documents_status.append(marshal(document, self.document_status_fields))
|
||||
documents_status.append(marshal(document, document_status_fields))
|
||||
data = {
|
||||
'data': documents_status
|
||||
}
|
||||
@@ -549,20 +468,6 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
|
||||
|
||||
class DocumentIndexingStatusApi(DocumentResource):
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
'indexing_status': fields.String,
|
||||
'processing_started_at': TimestampField,
|
||||
'parsing_completed_at': TimestampField,
|
||||
'cleaning_completed_at': TimestampField,
|
||||
'splitting_completed_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'paused_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer,
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -586,7 +491,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = 'paused'
|
||||
return marshal(document, self.document_status_fields)
|
||||
return marshal(document, document_status_fields)
|
||||
|
||||
|
||||
class DocumentDetailApi(DocumentResource):
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal
|
||||
from flask_restful import Resource, reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
import services
|
||||
@@ -14,48 +14,18 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import segment_fields
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
from libs.helper import TimestampField
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
import pandas as pd
|
||||
|
||||
segment_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'document_id': fields.String,
|
||||
'content': fields.String,
|
||||
'answer': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'tokens': fields.Integer,
|
||||
'keywords': fields.List(fields.String),
|
||||
'index_node_id': fields.String,
|
||||
'index_node_hash': fields.String,
|
||||
'hit_count': fields.Integer,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'status': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'indexing_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField
|
||||
}
|
||||
|
||||
segment_list_response = {
|
||||
'data': fields.List(fields.Nested(segment_fields)),
|
||||
'has_more': fields.Boolean,
|
||||
'limit': fields.Integer
|
||||
}
|
||||
|
||||
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
|
||||
@@ -1,28 +1,19 @@
|
||||
import datetime
|
||||
import hashlib
|
||||
import tempfile
|
||||
import chardet
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
|
||||
UnsupportedFileTypeError
|
||||
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from extensions.ext_storage import storage
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.model import UploadFile
|
||||
from fields.file_fields import upload_config_fields, file_fields
|
||||
|
||||
from services.file_service import FileService
|
||||
|
||||
cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
@@ -31,10 +22,6 @@ PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
class FileApi(Resource):
|
||||
upload_config_fields = {
|
||||
'file_size_limit': fields.Integer,
|
||||
'batch_count_limit': fields.Integer
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -48,16 +35,6 @@ class FileApi(Resource):
|
||||
'batch_count_limit': batch_count_limit
|
||||
}, 200
|
||||
|
||||
file_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'size': fields.Integer,
|
||||
'extension': fields.String,
|
||||
'mime_type': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -73,45 +50,13 @@ class FileApi(Resource):
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file_content = file.read()
|
||||
file_size = len(file_content)
|
||||
|
||||
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
|
||||
if file_size > file_size_limit:
|
||||
message = "({file_size} > {file_size_limit})"
|
||||
raise FileTooLargeError(message)
|
||||
|
||||
extension = file.filename.split('.')[-1]
|
||||
if extension.lower() not in ALLOWED_EXTENSIONS:
|
||||
try:
|
||||
upload_file = FileService.upload_file(file)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
# user uuid as file name
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, file_content)
|
||||
|
||||
# save file to db
|
||||
config = current_app.config
|
||||
upload_file = UploadFile(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
storage_type=config['STORAGE_TYPE'],
|
||||
key=file_key,
|
||||
name=file.filename,
|
||||
size=file_size,
|
||||
extension=extension,
|
||||
mime_type=file.mimetype,
|
||||
created_by=current_user.id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(file_content).hexdigest()
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
@@ -121,26 +66,7 @@ class FilePreviewApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
key = file_id + request.path
|
||||
cached_response = cache.get(key)
|
||||
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
|
||||
return cached_response['response']
|
||||
|
||||
upload_file = db.session.query(UploadFile) \
|
||||
.filter(UploadFile.id == file_id) \
|
||||
.first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
# extract text from file
|
||||
extension = upload_file.extension
|
||||
if extension.lower() not in ALLOWED_EXTENSIONS:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
text = FileExtractor.load(upload_file, return_text=True)
|
||||
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
|
||||
text = FileService.get_file_preview(file_id)
|
||||
return {'content': text}
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal, fields
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal
|
||||
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
|
||||
|
||||
import services
|
||||
@@ -14,48 +14,10 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
|
||||
LLMBadRequestError
|
||||
from libs.helper import TimestampField
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
document_fields = {
|
||||
'id': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'name': fields.String,
|
||||
'doc_type': fields.String,
|
||||
}
|
||||
|
||||
segment_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'document_id': fields.String,
|
||||
'content': fields.String,
|
||||
'answer': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'tokens': fields.Integer,
|
||||
'keywords': fields.List(fields.String),
|
||||
'index_node_id': fields.String,
|
||||
'index_node_hash': fields.String,
|
||||
'hit_count': fields.Integer,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'status': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'indexing_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'document': fields.Nested(document_fields),
|
||||
}
|
||||
|
||||
hit_testing_record_fields = {
|
||||
'segment': fields.Nested(segment_fields),
|
||||
'score': fields.Float,
|
||||
'tsne_position': fields.Raw
|
||||
}
|
||||
|
||||
|
||||
class HitTestingApi(Resource):
|
||||
|
||||
|
||||
@@ -7,26 +7,12 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_fields))
|
||||
}
|
||||
|
||||
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
|
||||
@@ -76,7 +62,7 @@ class ConversationApi(InstalledAppResource):
|
||||
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
|
||||
@marshal_with(conversation_fields)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != 'chat':
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, inputs
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with, inputs
|
||||
from sqlalchemy import and_
|
||||
from werkzeug.exceptions import NotFound, Forbidden, BadRequest
|
||||
|
||||
@@ -11,32 +11,10 @@ from controllers.console import api
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from fields.installed_app_fields import installed_app_list_fields
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
|
||||
app_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String
|
||||
}
|
||||
|
||||
installed_app_fields = {
|
||||
'id': fields.String,
|
||||
'app': fields.Nested(app_fields),
|
||||
'app_owner_tenant_id': fields.String,
|
||||
'is_pinned': fields.Boolean,
|
||||
'last_used_at': TimestampField,
|
||||
'editable': fields.Boolean,
|
||||
'uninstallable': fields.Boolean,
|
||||
}
|
||||
|
||||
installed_app_list_fields = {
|
||||
'installed_apps': fields.List(fields.Nested(installed_app_fields))
|
||||
}
|
||||
|
||||
|
||||
class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
|
||||
@@ -17,6 +17,7 @@ from controllers.console.explore.error import NotCompletionAppError, AppSuggeste
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.completion_service import CompletionService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
@@ -26,45 +27,6 @@ from services.message_service import MessageService
|
||||
|
||||
|
||||
class MessageListApi(InstalledAppResource):
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
|
||||
retriever_resource_fields = {
|
||||
'id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'dataset_id': fields.String,
|
||||
'dataset_name': fields.String,
|
||||
'document_id': fields.String,
|
||||
'document_name': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'segment_id': fields.String,
|
||||
'score': fields.Float,
|
||||
'hit_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'segment_position': fields.Integer,
|
||||
'index_node_hash': fields.String,
|
||||
'content': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
}
|
||||
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, installed_app):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import and_
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from functools import wraps
|
||||
|
||||
import flask_login
|
||||
from flask import request, current_app
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
@@ -58,9 +57,6 @@ class SetupApi(Resource):
|
||||
)
|
||||
|
||||
setup()
|
||||
|
||||
# Login
|
||||
flask_login.login_user(account)
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
return {'result': 'success'}, 201
|
||||
|
||||
@@ -6,31 +6,17 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.universal_chat.wraps import UniversalChatResource
|
||||
from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \
|
||||
conversation_with_model_config_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'model_config': fields.Raw,
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_fields))
|
||||
}
|
||||
|
||||
|
||||
class UniversalChatConversationListApi(UniversalChatResource):
|
||||
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
@marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
|
||||
def get(self, universal_app):
|
||||
app_model = universal_app
|
||||
|
||||
@@ -73,7 +59,7 @@ class UniversalChatConversationApi(UniversalChatResource):
|
||||
|
||||
class UniversalChatConversationRenameApi(UniversalChatResource):
|
||||
|
||||
@marshal_with(conversation_fields)
|
||||
@marshal_with(conversation_with_model_config_fields)
|
||||
def post(self, universal_app, c_id):
|
||||
app_model = universal_app
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
import pytz
|
||||
from flask import current_app, request
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
|
||||
|
||||
import services
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, abort, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -3,9 +3,8 @@ import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
|
||||
from flask_restful.inputs import int_range
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.admin import admin_required
|
||||
|
||||
@@ -9,4 +9,4 @@ api = ExternalApi(bp)
|
||||
|
||||
from .app import completion, app, conversation, message, audio
|
||||
|
||||
from .dataset import document
|
||||
from .dataset import document, segment, dataset
|
||||
|
||||
@@ -8,25 +8,11 @@ from controllers.service_api import api
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
import services
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_fields))
|
||||
}
|
||||
|
||||
|
||||
class ConversationApi(AppApiResource):
|
||||
|
||||
@@ -50,7 +36,7 @@ class ConversationApi(AppApiResource):
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
|
||||
class ConversationDetailApi(AppApiResource):
|
||||
@marshal_with(conversation_fields)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
@@ -70,7 +56,7 @@ class ConversationDetailApi(AppApiResource):
|
||||
|
||||
class ConversationRenameApi(AppApiResource):
|
||||
|
||||
@marshal_with(conversation_fields)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
81
api/controllers/service_api/dataset/dataset.py
Normal file
81
api/controllers/service_api/dataset/dataset.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from flask import request
|
||||
from flask_restful import reqparse, marshal
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetNameDuplicateError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from libs.login import current_user
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError('Name must be between 1 to 40 characters.')
|
||||
return name
|
||||
|
||||
|
||||
class DatasetApi(DatasetApiResource):
|
||||
"""Resource for get datasets."""
|
||||
|
||||
def get(self, tenant_id):
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
provider = request.args.get('provider', default="vendor")
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
tenant_id, current_user)
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||
ModelType.EMBEDDINGS.value)
|
||||
model_names = []
|
||||
for valid_model in valid_model_list:
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item['indexing_technique'] == 'high_quality':
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item['embedding_available'] = True
|
||||
else:
|
||||
item['embedding_available'] = False
|
||||
else:
|
||||
item['embedding_available'] = True
|
||||
response = {
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response, 200
|
||||
|
||||
"""Resource for datasets."""
|
||||
|
||||
def post(self, tenant_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', nullable=False, required=True,
|
||||
help='type is required. Name must be between 1 to 40 characters.',
|
||||
type=_validate_name)
|
||||
parser.add_argument('indexing_technique', type=str, location='json',
|
||||
choices=('high_quality', 'economy'),
|
||||
help='Invalid indexing technique.')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=args['name'],
|
||||
indexing_technique=args['indexing_technique'],
|
||||
account=current_user
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 200
|
||||
|
||||
|
||||
api.add_resource(DatasetApi, '/datasets')
|
||||
|
||||
@@ -1,114 +1,287 @@
|
||||
import datetime
|
||||
import uuid
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
from flask_restful import reqparse
|
||||
from flask import request
|
||||
from flask_restful import reqparse, marshal
|
||||
from sqlalchemy import desc
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
|
||||
DatasetNotInitedError
|
||||
NoFileUploadedError, TooManyFilesError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from libs.login import current_user
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.dataset_service import DocumentService
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class DocumentListApi(DatasetApiResource):
|
||||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
|
||||
def post(self, dataset):
|
||||
"""Create document."""
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by text."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('doc_type', type=str, location='json')
|
||||
parser.add_argument('doc_metadata', type=dict, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
|
||||
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset.indexing_technique:
|
||||
raise DatasetNotInitedError("Dataset indexing technique must be set.")
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
|
||||
doc_type = args.get('doc_type')
|
||||
doc_metadata = args.get('doc_metadata')
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
raise ValueError('indexing_technique is required.')
|
||||
|
||||
if doc_type and doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
|
||||
raise ValueError('Invalid doc_type.')
|
||||
|
||||
# user uuid as file name
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'upload_files/' + dataset.tenant_id + '/' + file_uuid + '.txt'
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, args.get('text'))
|
||||
|
||||
# save file to db
|
||||
config = current_app.config
|
||||
upload_file = UploadFile(
|
||||
tenant_id=dataset.tenant_id,
|
||||
storage_type=config['STORAGE_TYPE'],
|
||||
key=file_key,
|
||||
name=args.get('name') + '.txt',
|
||||
size=len(args.get('text')),
|
||||
extension='txt',
|
||||
mime_type='text/plain',
|
||||
created_by=dataset.created_by,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
used=True,
|
||||
used_by=dataset.created_by,
|
||||
used_at=datetime.datetime.utcnow()
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
document_data = {
|
||||
'data_source': {
|
||||
'type': 'upload_file',
|
||||
'info': [
|
||||
{
|
||||
'upload_file_id': upload_file.id
|
||||
}
|
||||
]
|
||||
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
|
||||
data_source = {
|
||||
'type': 'upload_file',
|
||||
'info_list': {
|
||||
'data_source_type': 'upload_file',
|
||||
'file_info_list': {
|
||||
'file_ids': [upload_file.id]
|
||||
}
|
||||
}
|
||||
}
|
||||
args['data_source'] = data_source
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
document_data=document_data,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule,
|
||||
document_data=args,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
if doc_type and doc_metadata:
|
||||
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
|
||||
|
||||
document.doc_metadata = {}
|
||||
|
||||
for key, value_type in metadata_schema.items():
|
||||
value = doc_metadata.get(key)
|
||||
if value is not None and isinstance(value, value_type):
|
||||
document.doc_metadata[key] = value
|
||||
|
||||
document.doc_type = doc_type
|
||||
document.updated_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
return {'id': document.id}
|
||||
documents_and_batch_fields = {
|
||||
'document': marshal(document, document_fields),
|
||||
'batch': batch
|
||||
}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
class DocumentApi(DatasetApiResource):
|
||||
def delete(self, dataset, document_id):
|
||||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by text."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('text', type=str, required=False, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
|
||||
if args['text']:
|
||||
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
|
||||
data_source = {
|
||||
'type': 'upload_file',
|
||||
'info_list': {
|
||||
'data_source_type': 'upload_file',
|
||||
'file_info_list': {
|
||||
'file_ids': [upload_file.id]
|
||||
}
|
||||
}
|
||||
}
|
||||
args['data_source'] = data_source
|
||||
# validate args
|
||||
args['original_document_id'] = str(document_id)
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
document_data=args,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {
|
||||
'document': marshal(document, document_fields),
|
||||
'batch': batch
|
||||
}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
class DocumentAddByFileApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by upload file."""
|
||||
args = {}
|
||||
if 'data' in request.form:
|
||||
args = json.loads(request.form['data'])
|
||||
if 'doc_form' not in args:
|
||||
args['doc_form'] = 'text_model'
|
||||
if 'doc_language' not in args:
|
||||
args['doc_language'] = 'English'
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
raise ValueError('indexing_technique is required.')
|
||||
|
||||
# save file info
|
||||
file = request.files['file']
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
upload_file = FileService.upload_file(file)
|
||||
data_source = {
|
||||
'type': 'upload_file',
|
||||
'info_list': {
|
||||
'file_info_list': {
|
||||
'file_ids': [upload_file.id]
|
||||
}
|
||||
}
|
||||
}
|
||||
args['data_source'] = data_source
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
document_data=args,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {
|
||||
'document': marshal(document, document_fields),
|
||||
'batch': batch
|
||||
}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
args = {}
|
||||
if 'data' in request.form:
|
||||
args = json.loads(request.form['data'])
|
||||
if 'doc_form' not in args:
|
||||
args['doc_form'] = 'text_model'
|
||||
if 'doc_language' not in args:
|
||||
args['doc_language'] = 'English'
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
if 'file' in request.files:
|
||||
# save file info
|
||||
file = request.files['file']
|
||||
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
upload_file = FileService.upload_file(file)
|
||||
data_source = {
|
||||
'type': 'upload_file',
|
||||
'info_list': {
|
||||
'file_info_list': {
|
||||
'file_ids': [upload_file.id]
|
||||
}
|
||||
}
|
||||
}
|
||||
args['data_source'] = data_source
|
||||
# validate args
|
||||
args['original_document_id'] = str(document_id)
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
document_data=args,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
|
||||
created_from='api'
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {
|
||||
'document': marshal(document, document_fields),
|
||||
'batch': batch
|
||||
}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
class DocumentDeleteApi(DatasetApiResource):
|
||||
def delete(self, tenant_id, dataset_id, document_id):
|
||||
"""Delete document."""
|
||||
document_id = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# get dataset info
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
@@ -126,8 +299,85 @@ class DocumentApi(DatasetApiResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError('Cannot delete document during indexing.')
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
api.add_resource(DocumentListApi, '/documents')
|
||||
api.add_resource(DocumentApi, '/documents/<uuid:document_id>')
|
||||
class DocumentListApi(DatasetApiResource):
|
||||
def get(self, tenant_id, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
search = request.args.get('keyword', default=None, type=str)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
|
||||
query = Document.query.filter_by(
|
||||
dataset_id=str(dataset_id), tenant_id=tenant_id)
|
||||
|
||||
if search:
|
||||
search = f'%{search}%'
|
||||
query = query.filter(Document.name.like(search))
|
||||
|
||||
query = query.order_by(desc(Document.created_at))
|
||||
|
||||
paginated_documents = query.paginate(
|
||||
page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
documents = paginated_documents.items
|
||||
|
||||
response = {
|
||||
'data': marshal(documents, document_fields),
|
||||
'has_more': len(documents) == limit,
|
||||
'limit': limit,
|
||||
'total': paginated_documents.total,
|
||||
'page': page
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class DocumentIndexingStatusApi(DatasetApiResource):
|
||||
def get(self, tenant_id, dataset_id, batch):
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
tenant_id = str(tenant_id)
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# get documents
|
||||
documents = DocumentService.get_batch_documents(dataset_id, batch)
|
||||
if not documents:
|
||||
raise NotFound('Documents not found.')
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != 're_segment').count()
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
if document.is_paused:
|
||||
document.indexing_status = 'paused'
|
||||
documents_status.append(marshal(document, document_status_fields))
|
||||
data = {
|
||||
'data': documents_status
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text')
|
||||
api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file')
|
||||
api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text')
|
||||
api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file')
|
||||
api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
|
||||
api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents')
|
||||
api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status')
|
||||
|
||||
@@ -1,20 +1,73 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = 'no_file_uploaded'
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = 'too_many_files'
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class FileTooLargeError(BaseHTTPException):
|
||||
error_code = 'file_too_large'
|
||||
description = "File size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_file_type'
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class HighQualityDatasetOnlyError(BaseHTTPException):
|
||||
error_code = 'high_quality_dataset_only'
|
||||
description = "Current operation only supports 'high-quality' datasets."
|
||||
code = 400
|
||||
|
||||
|
||||
class DatasetNotInitializedError(BaseHTTPException):
|
||||
error_code = 'dataset_not_initialized'
|
||||
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||
code = 400
|
||||
|
||||
|
||||
class ArchivedDocumentImmutableError(BaseHTTPException):
|
||||
error_code = 'archived_document_immutable'
|
||||
description = "Cannot operate when document was archived."
|
||||
description = "The archived document is not editable."
|
||||
code = 403
|
||||
|
||||
|
||||
class DatasetNameDuplicateError(BaseHTTPException):
|
||||
error_code = 'dataset_name_duplicate'
|
||||
description = "The dataset name already exists. Please modify your dataset name."
|
||||
code = 409
|
||||
|
||||
|
||||
class InvalidActionError(BaseHTTPException):
|
||||
error_code = 'invalid_action'
|
||||
description = "Invalid action."
|
||||
code = 400
|
||||
|
||||
|
||||
class DocumentAlreadyFinishedError(BaseHTTPException):
|
||||
error_code = 'document_already_finished'
|
||||
description = "The document has been processed. Please refresh the page or go to the document details."
|
||||
code = 400
|
||||
|
||||
|
||||
class DocumentIndexingError(BaseHTTPException):
|
||||
error_code = 'document_indexing'
|
||||
description = "Cannot operate document during indexing."
|
||||
code = 403
|
||||
description = "The document is being processed and cannot be edited."
|
||||
code = 400
|
||||
|
||||
|
||||
class DatasetNotInitedError(BaseHTTPException):
|
||||
error_code = 'dataset_not_inited'
|
||||
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||
code = 403
|
||||
class InvalidMetadataError(BaseHTTPException):
|
||||
error_code = 'invalid_metadata'
|
||||
description = "The metadata content is incorrect. Please check and verify."
|
||||
code = 400
|
||||
|
||||
201
api/controllers/service_api/dataset/segment.py
Normal file
201
api/controllers/service_api/dataset/segment.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import segment_fields
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
|
||||
|
||||
class SegmentApi(DatasetApiResource):
|
||||
"""Resource for segments."""
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Create single segment."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
for args_item in args['segments']:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
|
||||
return {
|
||||
'data': marshal(segments, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
"""Create single segment."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('status', type=str,
|
||||
action='append', default=[], location='args')
|
||||
parser.add_argument('keyword', type=str, default=None, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
status_list = args['status']
|
||||
keyword = args['keyword']
|
||||
|
||||
query = DocumentSegment.query.filter(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
)
|
||||
|
||||
if status_list:
|
||||
query = query.filter(DocumentSegment.status.in_(status_list))
|
||||
|
||||
if keyword:
|
||||
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
|
||||
|
||||
total = query.count()
|
||||
segments = query.order_by(DocumentSegment.position).all()
|
||||
return {
|
||||
'data': marshal(segments, segment_fields),
|
||||
'doc_form': document.doc_form,
|
||||
'total': total
|
||||
}, 200
|
||||
|
||||
|
||||
class DatasetSegmentApi(DatasetApiResource):
|
||||
def delete(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check segment
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound('Segment not found.')
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
def post(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound('Segment not found.')
|
||||
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
SegmentService.segment_create_args_validate(args['segments'], document)
|
||||
segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
||||
api.add_resource(DatasetSegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
|
||||
@@ -2,12 +2,14 @@
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask import request, current_app
|
||||
from flask_login import user_logged_in
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from libs.login import _get_user
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.account import Tenant, TenantAccountJoin, Account
|
||||
from models.model import ApiToken, App
|
||||
|
||||
|
||||
@@ -43,12 +45,24 @@ def validate_dataset_token(view=None):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token('dataset')
|
||||
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound()
|
||||
|
||||
return view(dataset, *args, **kwargs)
|
||||
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
|
||||
.filter(Tenant.id == api_token.tenant_id) \
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
||||
.filter(TenantAccountJoin.role == 'owner') \
|
||||
.one_or_none()
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = Account.query.filter_by(id=ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account)
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user())
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account is not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant is not exist.")
|
||||
return view(api_token.tenant_id, *args, **kwargs)
|
||||
return decorated
|
||||
|
||||
if view:
|
||||
|
||||
@@ -6,26 +6,12 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.web import api
|
||||
from controllers.web.error import NotChatAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_fields))
|
||||
}
|
||||
|
||||
|
||||
class ConversationListApi(WebApiResource):
|
||||
|
||||
@@ -73,7 +59,7 @@ class ConversationApi(WebApiResource):
|
||||
|
||||
class ConversationRenameApi(WebApiResource):
|
||||
|
||||
@marshal_with(conversation_fields)
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
|
||||
@@ -115,7 +115,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
|
||||
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
@@ -2,14 +2,18 @@ import json
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
@@ -24,6 +28,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
return values
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
@@ -65,17 +73,57 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
try:
|
||||
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
|
||||
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
else:
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
}
|
||||
)
|
||||
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
return agent_decision
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
@@ -87,7 +135,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
@@ -96,11 +144,15 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
prompt = cls.create_prompt(
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -5,21 +5,40 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
|
||||
_format_intermediate_steps
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
|
||||
get_buffer_string
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_model_instance: BaseLLM = None
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
@@ -28,12 +47,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
prompt = cls.create_prompt(
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=cls.get_system_message(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -44,23 +67,26 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 40
|
||||
original_max_tokens = self.model_instance.model_kwargs.max_tokens
|
||||
self.model_instance.model_kwargs.max_tokens = 40
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
try:
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
callbacks=None
|
||||
)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
function_call = result.function_call
|
||||
|
||||
self.llm.max_tokens = original_max_tokens
|
||||
self.model_instance.model_kwargs.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
|
||||
@@ -93,10 +119,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
}
|
||||
)
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
|
||||
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
|
||||
tool_inputs = agent_decision.tool_input
|
||||
@@ -122,3 +157,142 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
|
||||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
||||
system_message = None
|
||||
human_message = None
|
||||
should_summary_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
system_message = message
|
||||
elif isinstance(message, HumanMessage):
|
||||
human_message = message
|
||||
else:
|
||||
should_summary_messages.append(message)
|
||||
|
||||
if len(should_summary_messages) > 2:
|
||||
ai_message = should_summary_messages[-2]
|
||||
function_message = should_summary_messages[-1]
|
||||
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
|
||||
self.moving_summary_index = len(should_summary_messages)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
new_messages = [system_message, human_message]
|
||||
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, human_message)
|
||||
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
new_messages.append(AIMessage(content=self.moving_summary_buffer))
|
||||
new_messages.append(ai_message)
|
||||
new_messages.append(function_message)
|
||||
|
||||
return new_messages
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix="Human",
|
||||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
if model_instance.model_provider.provider_name == 'azure_openai':
|
||||
model = model_instance.base_model_name
|
||||
model = model.replace("gpt-35", "gpt-3.5")
|
||||
else:
|
||||
model = model_instance.base_model_name
|
||||
|
||||
tiktoken_ = _import_tiktoken()
|
||||
try:
|
||||
encoding = tiktoken_.encoding_for_model(model)
|
||||
except KeyError:
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken_.get_encoding(model)
|
||||
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
for m in messages:
|
||||
message = _convert_message_to_dict(m)
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if key == "function_call":
|
||||
for f_key, f_value in value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if kwargs.get('functions'):
|
||||
for function in kwargs.get('functions'):
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode(function.get("name")))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode(function.get("description")))
|
||||
parameters = function.get("parameters")
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
from typing import cast, List
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.chat_models.openai import _convert_message_to_dict
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
|
||||
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel = None
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
||||
system_message = None
|
||||
human_message = None
|
||||
should_summary_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
system_message = message
|
||||
elif isinstance(message, HumanMessage):
|
||||
human_message = message
|
||||
else:
|
||||
should_summary_messages.append(message)
|
||||
|
||||
if len(should_summary_messages) > 2:
|
||||
ai_message = should_summary_messages[-2]
|
||||
function_message = should_summary_messages[-1]
|
||||
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
|
||||
self.moving_summary_index = len(should_summary_messages)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
new_messages = [system_message, human_message]
|
||||
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, human_message)
|
||||
|
||||
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
new_messages.append(AIMessage(content=self.moving_summary_buffer))
|
||||
new_messages.append(ai_message)
|
||||
new_messages.append(function_message)
|
||||
|
||||
return new_messages
|
||||
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
llm = cast(ChatOpenAI, model_instance.client)
|
||||
model, encoding = llm._get_encoding_model()
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
for m in messages:
|
||||
message = _convert_message_to_dict(m)
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if key == "function_call":
|
||||
for f_key, f_value in value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if kwargs.get('functions'):
|
||||
for function in kwargs.get('functions'):
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode(function.get("name")))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode(function.get("description")))
|
||||
parameters = function.get("parameters")
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
@@ -1,107 +0,0 @@
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
|
||||
from langchain.agents import BaseMultiActionAgent
|
||||
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
|
||||
_parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||
|
||||
|
||||
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseMultiActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=cls.get_system_message(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 15
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
try:
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
|
||||
self.llm.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
|
||||
@classmethod
|
||||
def get_system_message(cls):
|
||||
# get current time
|
||||
return SystemMessage(content="You are a helpful AI assistant.\n"
|
||||
"The current date or current time you know is wrong.\n"
|
||||
"Respond directly if appropriate.")
|
||||
@@ -4,7 +4,6 @@ from typing import List, Tuple, Any, Union, Sequence, Optional, cast
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
@@ -12,6 +11,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
@@ -49,7 +49,6 @@ Action:
|
||||
|
||||
|
||||
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
model_instance: BaseLLM
|
||||
dataset_tools: Sequence[BaseTool]
|
||||
|
||||
class Config:
|
||||
@@ -98,7 +97,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
try:
|
||||
@@ -108,6 +107,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
else:
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
@@ -145,7 +146,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
@@ -157,17 +158,28 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
output_parser=_output_parser,
|
||||
dataset_tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -4,16 +4,17 @@ from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
|
||||
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
|
||||
get_buffer_string
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
@@ -52,8 +53,7 @@ Action:
|
||||
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel = None
|
||||
model_instance: BaseLLM
|
||||
summary_model_instance: BaseLLM = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -95,14 +95,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
if prompts:
|
||||
messages = prompts[0].to_messages()
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
|
||||
if rest_tokens < 0:
|
||||
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
try:
|
||||
@@ -118,7 +118,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2 and self.summary_llm:
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_instance:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
for _, observation in should_summary_intermediate_steps]
|
||||
@@ -130,11 +130,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||
if self.moving_summary_buffer and 'chat_history' in kwargs:
|
||||
kwargs["chat_history"].pop()
|
||||
|
||||
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
@@ -144,6 +143,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
|
||||
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix="Human",
|
||||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
@@ -176,7 +187,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
@@ -188,16 +199,27 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
output_parser=_output_parser,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import BaseModel, Extra
|
||||
|
||||
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
|
||||
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
|
||||
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
|
||||
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||
@@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum):
|
||||
REACT_ROUTER = 'react_router'
|
||||
REACT = 'react'
|
||||
FUNCTION_CALL = 'function_call'
|
||||
MULTI_FUNCTION_CALL = 'multi_function_call'
|
||||
|
||||
|
||||
class AgentConfiguration(BaseModel):
|
||||
@@ -64,30 +62,18 @@ class AgentExecutor:
|
||||
if self.configuration.strategy == PlanningStrategy.REACT:
|
||||
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
summary_model_instance=self.configuration.summary_model_instance
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
summary_model_instance=self.configuration.summary_model_instance
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
@@ -95,7 +81,6 @@ class AgentExecutor:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
|
||||
verbose=True
|
||||
@@ -104,7 +89,6 @@ class AgentExecutor:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
verbose=True
|
||||
|
||||
36
api/core/chain/llm_chain.py
Normal file
36
api/core/chain/llm_chain.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from langchain import LLMChain as LCLLMChain
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema import LLMResult, Generation
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
class LLMChain(LCLLMChain):
|
||||
model_instance: BaseLLM
|
||||
"""The language model instance to use."""
|
||||
llm: BaseLanguageModel = FakeLLM(response="")
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||
messages = prompts[0].to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
stop=stop
|
||||
)
|
||||
|
||||
generations = [
|
||||
[Generation(text=result.content)]
|
||||
]
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List, Union
|
||||
|
||||
@@ -16,10 +15,8 @@ from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
from models.dataset import DocumentSegment, Dataset, Document
|
||||
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from models.model import App, AppModelConfig, Account, Conversation, EndUser
|
||||
|
||||
|
||||
class Completion:
|
||||
@@ -30,7 +27,7 @@ class Completion:
|
||||
"""
|
||||
errors: ProviderTokenNotInitError
|
||||
"""
|
||||
query = PromptBuilder.process_template(query)
|
||||
query = PromptTemplateParser.remove_template_variables(query)
|
||||
|
||||
memory = None
|
||||
if conversation:
|
||||
@@ -108,12 +105,14 @@ class Completion:
|
||||
retriever_from=retriever_from
|
||||
)
|
||||
|
||||
query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)
|
||||
|
||||
# run agent executor
|
||||
agent_execute_result = None
|
||||
if agent_executor:
|
||||
should_use_agent = agent_executor.should_use_agent(query)
|
||||
if query_for_agent and agent_executor:
|
||||
should_use_agent = agent_executor.should_use_agent(query_for_agent)
|
||||
if should_use_agent:
|
||||
agent_execute_result = agent_executor.run(query)
|
||||
agent_execute_result = agent_executor.run(query_for_agent)
|
||||
|
||||
# When no extra pre prompt is specified,
|
||||
# the output of the agent can be used directly as the main output content without calling LLM again
|
||||
@@ -142,6 +141,13 @@ class Completion:
|
||||
logging.warning(f'ChunkedEncodingError: {e}')
|
||||
conversation_message_task.end()
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
|
||||
if app.mode != 'completion':
|
||||
return query
|
||||
|
||||
return inputs.get(app_model_config.dataset_query_variable, "")
|
||||
|
||||
@classmethod
|
||||
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
|
||||
@@ -151,14 +157,28 @@ class Completion:
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
|
||||
fake_response: Optional[str]):
|
||||
# get llm prompt
|
||||
prompt_messages, stop_words = model_instance.get_prompt(
|
||||
mode=mode,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
context=agent_execute_result.output if agent_execute_result else None,
|
||||
memory=memory
|
||||
)
|
||||
if app_model_config.prompt_type == 'simple':
|
||||
prompt_messages, stop_words = model_instance.get_prompt(
|
||||
mode=mode,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
context=agent_execute_result.output if agent_execute_result else None,
|
||||
memory=memory
|
||||
)
|
||||
else:
|
||||
prompt_messages = model_instance.get_advanced_prompt(
|
||||
app_mode=mode,
|
||||
app_model_config=app_model_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
context=agent_execute_result.output if agent_execute_result else None,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
model_config = app_model_config.model_dict
|
||||
completion_params = model_config.get("completion_params", {})
|
||||
stop_words = completion_params.get("stop", [])
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
model_instance=model_instance,
|
||||
@@ -167,7 +187,7 @@ class Completion:
|
||||
|
||||
response = model_instance.run(
|
||||
messages=prompt_messages,
|
||||
stop=stop_words,
|
||||
stop=stop_words if stop_words else None,
|
||||
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
|
||||
fake_response=fake_response
|
||||
)
|
||||
@@ -257,52 +277,3 @@ class Completion:
|
||||
model_kwargs = model_instance.get_model_kwargs()
|
||||
model_kwargs.max_tokens = max_tokens
|
||||
model_instance.set_model_kwargs(model_kwargs)
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
|
||||
app_model_config: AppModelConfig, user: Account, streaming: bool):
|
||||
|
||||
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
|
||||
tenant_id=app.tenant_id,
|
||||
model_config=app_model_config.model_dict,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
# get llm prompt
|
||||
old_prompt_messages, _ = final_model_instance.get_prompt(
|
||||
mode='completion',
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
context=None,
|
||||
memory=None
|
||||
)
|
||||
|
||||
original_completion = message.answer.strip()
|
||||
|
||||
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
|
||||
|
||||
prompt_messages = [PromptMessage(content=prompt)]
|
||||
|
||||
conversation_message_task = ConversationMessageTask(
|
||||
task_id=task_id,
|
||||
app=app,
|
||||
app_model_config=app_model_config,
|
||||
user=user,
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
is_override=True if message.override_model_configs else False,
|
||||
streaming=streaming,
|
||||
model_instance=final_model_instance
|
||||
)
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
model_instance=final_model_instance,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
final_model_instance.run(
|
||||
messages=prompt_messages,
|
||||
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@@ -74,10 +74,10 @@ class ConversationMessageTask:
|
||||
if self.mode == 'chat':
|
||||
introduction = self.app_model_config.opening_statement
|
||||
if introduction:
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=introduction)
|
||||
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
|
||||
prompt_template = PromptTemplateParser(template=introduction)
|
||||
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
|
||||
try:
|
||||
introduction = prompt_template.format(**prompt_inputs)
|
||||
introduction = prompt_template.format(prompt_inputs)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@@ -94,7 +94,7 @@ class ConversationMessageTask:
|
||||
if not self.conversation:
|
||||
self.is_new_conversation = True
|
||||
self.conversation = Conversation(
|
||||
app_id=self.app_model_config.app_id,
|
||||
app_id=self.app.id,
|
||||
app_model_config_id=self.app_model_config.id,
|
||||
model_provider=self.provider_name,
|
||||
model_id=self.model_name,
|
||||
@@ -112,10 +112,10 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
db.session.add(self.conversation)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
self.message = Message(
|
||||
app_id=self.app_model_config.app_id,
|
||||
app_id=self.app.id,
|
||||
model_provider=self.provider_name,
|
||||
model_id=self.model_name,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
@@ -140,7 +140,7 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
db.session.add(self.message)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
def append_message_text(self, text: str):
|
||||
if text is not None:
|
||||
@@ -150,12 +150,12 @@ class ConversationMessageTask:
|
||||
message_tokens = llm_message.prompt_tokens
|
||||
answer_tokens = llm_message.completion_tokens
|
||||
|
||||
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
|
||||
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
|
||||
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
|
||||
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
|
||||
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
|
||||
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
|
||||
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
|
||||
total_price = message_total_price + answer_total_price
|
||||
|
||||
@@ -163,7 +163,7 @@ class ConversationMessageTask:
|
||||
self.message.message_tokens = message_tokens
|
||||
self.message.message_unit_price = message_unit_price
|
||||
self.message.message_price_unit = message_price_unit
|
||||
self.message.answer = PromptBuilder.process_template(
|
||||
self.message.answer = PromptTemplateParser.remove_template_variables(
|
||||
llm_message.completion.strip()) if llm_message.completion else ''
|
||||
self.message.answer_tokens = answer_tokens
|
||||
self.message.answer_unit_price = answer_unit_price
|
||||
@@ -191,12 +191,13 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
db.session.add(message_chain)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
return message_chain
|
||||
|
||||
def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
|
||||
message_chain.output = json.dumps(chain_result.completion)
|
||||
db.session.commit()
|
||||
|
||||
self._pub_handler.pub_chain(message_chain)
|
||||
|
||||
@@ -217,7 +218,7 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
db.session.add(message_agent_thought)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
self._pub_handler.pub_agent_thought(message_agent_thought)
|
||||
|
||||
@@ -225,15 +226,15 @@ class ConversationMessageTask:
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
|
||||
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
|
||||
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
|
||||
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
|
||||
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
loop_message_tokens = agent_loop.prompt_tokens
|
||||
loop_answer_tokens = agent_loop.completion_tokens
|
||||
|
||||
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
|
||||
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
|
||||
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
|
||||
loop_total_price = loop_message_total_price + loop_answer_total_price
|
||||
|
||||
@@ -249,7 +250,7 @@ class ConversationMessageTask:
|
||||
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
|
||||
message_agent_thought.total_price = loop_total_price
|
||||
message_agent_thought.currency = agent_model_instance.get_currency()
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
|
||||
dataset_query = DatasetQuery(
|
||||
@@ -262,6 +263,7 @@ class ConversationMessageTask:
|
||||
)
|
||||
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
|
||||
def on_dataset_query_finish(self, resource: List):
|
||||
if resource and len(resource) > 0:
|
||||
@@ -285,7 +287,7 @@ class ConversationMessageTask:
|
||||
created_by=self.user.id
|
||||
)
|
||||
db.session.add(dataset_retriever_resource)
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
self.retriever_resource = resource
|
||||
|
||||
def message_end(self):
|
||||
|
||||
@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
|
||||
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
|
||||
SEARCH_URL = "https://api.notion.com/v1/search"
|
||||
|
||||
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
|
||||
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
|
||||
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
|
||||
|
||||
@@ -10,9 +10,8 @@ from core.model_providers.models.entity.model_params import ModelKwargs
|
||||
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||
|
||||
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
|
||||
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
|
||||
GENERATOR_QA_PROMPT
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
|
||||
|
||||
|
||||
class LLMGenerator:
|
||||
@@ -44,78 +43,19 @@ class LLMGenerator:
|
||||
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_conversation_summary(cls, tenant_id: str, messages):
|
||||
max_tokens = 200
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
)
|
||||
|
||||
prompt = CONVERSATION_SUMMARY_PROMPT
|
||||
prompt_with_empty_context = prompt.format(context='')
|
||||
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
|
||||
max_context_token_length = model_instance.model_rules.max_tokens.max
|
||||
max_context_token_length = max_context_token_length if max_context_token_length else 1500
|
||||
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
|
||||
|
||||
context = ''
|
||||
for message in messages:
|
||||
if not message.answer:
|
||||
continue
|
||||
|
||||
if len(message.query) > 2000:
|
||||
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
|
||||
else:
|
||||
query = message.query
|
||||
|
||||
if len(message.answer) > 2000:
|
||||
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
|
||||
else:
|
||||
answer = message.answer
|
||||
|
||||
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
|
||||
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
|
||||
context += message_qa_text
|
||||
|
||||
if not context:
|
||||
return '[message too long, no summary]'
|
||||
|
||||
prompt = prompt.format(context=context)
|
||||
prompts = [PromptMessage(content=prompt)]
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_introduction(cls, tenant_id: str, pre_prompt: str):
|
||||
prompt = INTRODUCTION_GENERATE_PROMPT
|
||||
prompt = prompt.format(prompt=pre_prompt)
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
prompts = [PromptMessage(content=prompt)]
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
|
||||
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
prompt = JinjaPromptTemplate(
|
||||
template="{{histories}}\n{{format_instructions}}\nquestions:\n",
|
||||
input_variables=["histories"],
|
||||
partial_variables={"format_instructions": format_instructions}
|
||||
prompt_template = PromptTemplateParser(
|
||||
template="{{histories}}\n{{format_instructions}}\nquestions:\n"
|
||||
)
|
||||
|
||||
_input = prompt.format_prompt(histories=histories)
|
||||
prompt = prompt_template.format({
|
||||
"histories": histories,
|
||||
"format_instructions": format_instructions
|
||||
})
|
||||
|
||||
try:
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
@@ -128,10 +68,10 @@ class LLMGenerator:
|
||||
except ProviderTokenNotInitError:
|
||||
return []
|
||||
|
||||
prompts = [PromptMessage(content=_input.to_string())]
|
||||
prompt_messages = [PromptMessage(content=prompt)]
|
||||
|
||||
try:
|
||||
output = model_instance.run(prompts)
|
||||
output = model_instance.run(prompt_messages)
|
||||
questions = output_parser.parse(output.content)
|
||||
except LLMError:
|
||||
questions = []
|
||||
@@ -145,19 +85,21 @@ class LLMGenerator:
|
||||
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
|
||||
output_parser = RuleConfigGeneratorOutputParser()
|
||||
|
||||
prompt = OutLinePromptTemplate(
|
||||
template=output_parser.get_format_instructions(),
|
||||
input_variables=["audiences", "hoping_to_solve"],
|
||||
partial_variables={
|
||||
"variable": '{variable}',
|
||||
"lanA": '{lanA}',
|
||||
"lanB": '{lanB}',
|
||||
"topic": '{topic}'
|
||||
},
|
||||
validate_template=False
|
||||
prompt_template = PromptTemplateParser(
|
||||
template=output_parser.get_format_instructions()
|
||||
)
|
||||
|
||||
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
|
||||
prompt = prompt_template.format(
|
||||
inputs={
|
||||
"audiences": audiences,
|
||||
"hoping_to_solve": hoping_to_solve,
|
||||
"variable": "{{variable}}",
|
||||
"lanA": "{{lanA}}",
|
||||
"lanB": "{{lanB}}",
|
||||
"topic": "{{topic}}"
|
||||
},
|
||||
remove_template_variables=False
|
||||
)
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
@@ -167,10 +109,10 @@ class LLMGenerator:
|
||||
)
|
||||
)
|
||||
|
||||
prompts = [PromptMessage(content=_input.to_string())]
|
||||
prompt_messages = [PromptMessage(content=prompt)]
|
||||
|
||||
try:
|
||||
output = model_instance.run(prompts)
|
||||
output = model_instance.run(prompt_messages)
|
||||
rule_config = output_parser.parse(output.content)
|
||||
except LLMError as e:
|
||||
raise e
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import openai
|
||||
|
||||
@@ -16,19 +17,20 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
|
||||
length = 2000
|
||||
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
|
||||
|
||||
max_text_chunks = 32
|
||||
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
|
||||
if len(text_chunks) == 0:
|
||||
return True
|
||||
|
||||
for text_chunk in chunks:
|
||||
try:
|
||||
moderation_result = openai.Moderation.create(input=text_chunk,
|
||||
api_key=hosted_model_providers.openai.api_key)
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
text_chunk = random.choice(text_chunks)
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result['flagged'] is True:
|
||||
return False
|
||||
try:
|
||||
moderation_result = openai.Moderation.create(input=text_chunk,
|
||||
api_key=hosted_model_providers.openai.api_key)
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result['flagged'] is True:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -246,11 +246,28 @@ class KeywordTableIndex(BaseIndex):
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def multi_create_segment_keywords(self, pre_segment_data_list: list):
|
||||
keyword_table_handler = JiebaKeywordTableHandler()
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
for pre_segment_data in pre_segment_data_list:
|
||||
segment = pre_segment_data['segment']
|
||||
if pre_segment_data['keywords']:
|
||||
segment.keywords = pre_segment_data['keywords']
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
|
||||
pre_segment_data['keywords'])
|
||||
else:
|
||||
keywords = keyword_table_handler.extract_keywords(segment.content,
|
||||
self._config.max_keywords_per_chunk)
|
||||
segment.keywords = list(keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
|
||||
class KeywordTableRetriever(BaseRetriever, BaseModel):
|
||||
index: KeywordTableIndex
|
||||
search_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
@@ -113,8 +113,10 @@ class BaseVectorIndex(BaseIndex):
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.delete()
|
||||
if self.dataset.collection_binding_id:
|
||||
vector_store.delete_by_group_id(group_id)
|
||||
else:
|
||||
vector_store.delete()
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
@@ -283,7 +285,7 @@ class BaseVectorIndex(BaseIndex):
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
|
||||
self.add_texts(documents)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
858
api/core/index/vector_index/milvus.py
Normal file
858
api/core/index/vector_index/milvus.py
Normal file
@@ -0,0 +1,858 @@
|
||||
"""Wrapper around the Milvus vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Union, Sequence
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MILVUS_CONNECTION = {
|
||||
"host": "localhost",
|
||||
"port": "19530",
|
||||
"user": "",
|
||||
"password": "",
|
||||
"secure": False,
|
||||
}
|
||||
|
||||
|
||||
class Milvus(VectorStore):
|
||||
"""Initialize wrapper around the milvus vector database.
|
||||
|
||||
In order to use this you need to have `pymilvus` installed and a
|
||||
running Milvus
|
||||
|
||||
See the following documentation for how to run a Milvus instance:
|
||||
https://milvus.io/docs/install_standalone-docker.md
|
||||
|
||||
If looking for a hosted Milvus, take a look at this documentation:
|
||||
https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
|
||||
this project,
|
||||
|
||||
IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
|
||||
|
||||
Args:
|
||||
embedding_function (Embeddings): Function used to embed the text.
|
||||
collection_name (str): Which Milvus collection to use. Defaults to
|
||||
"LangChainCollection".
|
||||
connection_args (Optional[dict[str, any]]): The connection args used for
|
||||
this class comes in the form of a dict.
|
||||
consistency_level (str): The consistency level to use for a collection.
|
||||
Defaults to "Session".
|
||||
index_params (Optional[dict]): Which index params to use. Defaults to
|
||||
HNSW/AUTOINDEX depending on service.
|
||||
search_params (Optional[dict]): Which search params to use. Defaults to
|
||||
default of index.
|
||||
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
|
||||
to False.
|
||||
|
||||
The connection args used for this class comes in the form of a dict,
|
||||
here are a few of the options:
|
||||
address (str): The actual address of Milvus
|
||||
instance. Example address: "localhost:19530"
|
||||
uri (str): The uri of Milvus instance. Example uri:
|
||||
"http://randomwebsite:19530",
|
||||
"tcp:foobarsite:19530",
|
||||
"https://ok.s3.south.com:19530".
|
||||
host (str): The host of Milvus instance. Default at "localhost",
|
||||
PyMilvus will fill in the default host if only port is provided.
|
||||
port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
|
||||
will fill in the default port if only host is provided.
|
||||
user (str): Use which user to connect to Milvus instance. If user and
|
||||
password are provided, we will add related header in every RPC call.
|
||||
password (str): Required when user is provided. The password
|
||||
corresponding to the user.
|
||||
secure (bool): Default is false. If set to true, tls will be enabled.
|
||||
client_key_path (str): If use tls two-way authentication, need to
|
||||
write the client.key path.
|
||||
client_pem_path (str): If use tls two-way authentication, need to
|
||||
write the client.pem path.
|
||||
ca_pem_path (str): If use tls two-way authentication, need to write
|
||||
the ca.pem path.
|
||||
server_pem_path (str): If use tls one-way authentication, need to
|
||||
write the server.pem path.
|
||||
server_name (str): If use tls, need to write the common name.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import Milvus
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
embedding = OpenAIEmbeddings()
|
||||
# Connect to a milvus instance on localhost
|
||||
milvus_store = Milvus(
|
||||
embedding_function = Embeddings,
|
||||
collection_name = "LangChainCollection",
|
||||
drop_old = True,
|
||||
)
|
||||
|
||||
Raises:
|
||||
ValueError: If the pymilvus python package is not installed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_function: Embeddings,
|
||||
collection_name: str = "LangChainCollection",
|
||||
connection_args: Optional[dict[str, Any]] = None,
|
||||
consistency_level: str = "Session",
|
||||
index_params: Optional[dict] = None,
|
||||
search_params: Optional[dict] = None,
|
||||
drop_old: Optional[bool] = False,
|
||||
):
|
||||
"""Initialize the Milvus vector store."""
|
||||
try:
|
||||
from pymilvus import Collection, utility
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
|
||||
# Default search params when one is not provided.
|
||||
self.default_search_params = {
|
||||
"IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"HNSW": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
|
||||
"ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
|
||||
"AUTOINDEX": {"metric_type": "L2", "params": {}},
|
||||
}
|
||||
|
||||
self.embedding_func = embedding_function
|
||||
self.collection_name = collection_name
|
||||
self.index_params = index_params
|
||||
self.search_params = search_params
|
||||
self.consistency_level = consistency_level
|
||||
|
||||
# In order for a collection to be compatible, pk needs to be auto'id and int
|
||||
self._primary_field = "id"
|
||||
# In order for compatibility, the text field will need to be called "text"
|
||||
self._text_field = "page_content"
|
||||
# In order for compatibility, the vector field needs to be called "vector"
|
||||
self._vector_field = "vectors"
|
||||
# In order for compatibility, the metadata field will need to be called "metadata"
|
||||
self._metadata_field = "metadata"
|
||||
self.fields: list[str] = []
|
||||
# Create the connection to the server
|
||||
if connection_args is None:
|
||||
connection_args = DEFAULT_MILVUS_CONNECTION
|
||||
self.alias = self._create_connection_alias(connection_args)
|
||||
self.col: Optional[Collection] = None
|
||||
|
||||
# Grab the existing collection if it exists
|
||||
if utility.has_collection(self.collection_name, using=self.alias):
|
||||
self.col = Collection(
|
||||
self.collection_name,
|
||||
using=self.alias,
|
||||
)
|
||||
# If need to drop old, drop it
|
||||
if drop_old and isinstance(self.col, Collection):
|
||||
self.col.drop()
|
||||
self.col = None
|
||||
|
||||
# Initialize the vector store
|
||||
self._init()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_func
|
||||
|
||||
def _create_connection_alias(self, connection_args: dict) -> str:
|
||||
"""Create the connection to the Milvus server."""
|
||||
from pymilvus import MilvusException, connections
|
||||
|
||||
# Grab the connection arguments that are used for checking existing connection
|
||||
host: str = connection_args.get("host", None)
|
||||
port: Union[str, int] = connection_args.get("port", None)
|
||||
address: str = connection_args.get("address", None)
|
||||
uri: str = connection_args.get("uri", None)
|
||||
user = connection_args.get("user", None)
|
||||
|
||||
# Order of use is host/port, uri, address
|
||||
if host is not None and port is not None:
|
||||
given_address = str(host) + ":" + str(port)
|
||||
elif uri is not None:
|
||||
given_address = uri.split("https://")[1]
|
||||
elif address is not None:
|
||||
given_address = address
|
||||
else:
|
||||
given_address = None
|
||||
logger.debug("Missing standard address type for reuse atttempt")
|
||||
|
||||
# User defaults to empty string when getting connection info
|
||||
if user is not None:
|
||||
tmp_user = user
|
||||
else:
|
||||
tmp_user = ""
|
||||
|
||||
# If a valid address was given, then check if a connection exists
|
||||
if given_address is not None:
|
||||
for con in connections.list_connections():
|
||||
addr = connections.get_connection_addr(con[0])
|
||||
if (
|
||||
con[1]
|
||||
and ("address" in addr)
|
||||
and (addr["address"] == given_address)
|
||||
and ("user" in addr)
|
||||
and (addr["user"] == tmp_user)
|
||||
):
|
||||
logger.debug("Using previous connection: %s", con[0])
|
||||
return con[0]
|
||||
|
||||
# Generate a new connection if one doesn't exist
|
||||
alias = uuid4().hex
|
||||
try:
|
||||
connections.connect(alias=alias, **connection_args)
|
||||
logger.debug("Created new connection using: %s", alias)
|
||||
return alias
|
||||
except MilvusException as e:
|
||||
logger.error("Failed to create new connection using: %s", alias)
|
||||
raise e
|
||||
|
||||
def _init(
|
||||
self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
|
||||
) -> None:
|
||||
if embeddings is not None:
|
||||
self._create_collection(embeddings, metadatas)
|
||||
self._extract_fields()
|
||||
self._create_index()
|
||||
self._create_search_params()
|
||||
self._load()
|
||||
|
||||
def _create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None
|
||||
) -> None:
|
||||
from pymilvus import (
|
||||
Collection,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
FieldSchema,
|
||||
MilvusException,
|
||||
)
|
||||
from pymilvus.orm.types import infer_dtype_bydata
|
||||
|
||||
# Determine embedding dim
|
||||
dim = len(embeddings[0])
|
||||
fields = []
|
||||
# Determine metadata schema
|
||||
# if metadatas:
|
||||
# # Create FieldSchema for each entry in metadata.
|
||||
# for key, value in metadatas[0].items():
|
||||
# # Infer the corresponding datatype of the metadata
|
||||
# dtype = infer_dtype_bydata(value)
|
||||
# # Datatype isn't compatible
|
||||
# if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
|
||||
# logger.error(
|
||||
# "Failure to create collection, unrecognized dtype for key: %s",
|
||||
# key,
|
||||
# )
|
||||
# raise ValueError(f"Unrecognized datatype for {key}.")
|
||||
# # Dataype is a string/varchar equivalent
|
||||
# elif dtype == DataType.VARCHAR:
|
||||
# fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
|
||||
# else:
|
||||
# fields.append(FieldSchema(key, dtype))
|
||||
if metadatas:
|
||||
fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535))
|
||||
|
||||
# Create the text field
|
||||
fields.append(
|
||||
FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
|
||||
)
|
||||
# Create the primary key field
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
self._primary_field, DataType.INT64, is_primary=True, auto_id=True
|
||||
)
|
||||
)
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(
|
||||
FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
|
||||
)
|
||||
|
||||
# Create the schema for the collection
|
||||
schema = CollectionSchema(fields)
|
||||
|
||||
# Create the collection
|
||||
try:
|
||||
self.col = Collection(
|
||||
name=self.collection_name,
|
||||
schema=schema,
|
||||
consistency_level=self.consistency_level,
|
||||
using=self.alias,
|
||||
)
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to create collection: %s error: %s", self.collection_name, e
|
||||
)
|
||||
raise e
|
||||
|
||||
def _extract_fields(self) -> None:
|
||||
"""Grab the existing fields from the Collection"""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection):
|
||||
schema = self.col.schema
|
||||
for x in schema.fields:
|
||||
self.fields.append(x.name)
|
||||
# Since primary field is auto-id, no need to track it
|
||||
self.fields.remove(self._primary_field)
|
||||
|
||||
def _get_index(self) -> Optional[dict[str, Any]]:
|
||||
"""Return the vector index information if it exists"""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection):
|
||||
for x in self.col.indexes:
|
||||
if x.field_name == self._vector_field:
|
||||
return x.to_dict()
|
||||
return None
|
||||
|
||||
def _create_index(self) -> None:
|
||||
"""Create a index on the collection"""
|
||||
from pymilvus import Collection, MilvusException
|
||||
|
||||
if isinstance(self.col, Collection) and self._get_index() is None:
|
||||
try:
|
||||
# If no index params, use a default HNSW based one
|
||||
if self.index_params is None:
|
||||
self.index_params = {
|
||||
"metric_type": "IP",
|
||||
"index_type": "HNSW",
|
||||
"params": {"M": 8, "efConstruction": 64},
|
||||
}
|
||||
|
||||
try:
|
||||
self.col.create_index(
|
||||
self._vector_field,
|
||||
index_params=self.index_params,
|
||||
using=self.alias,
|
||||
)
|
||||
|
||||
# If default did not work, most likely on Zilliz Cloud
|
||||
except MilvusException:
|
||||
# Use AUTOINDEX based index
|
||||
self.index_params = {
|
||||
"metric_type": "L2",
|
||||
"index_type": "AUTOINDEX",
|
||||
"params": {},
|
||||
}
|
||||
self.col.create_index(
|
||||
self._vector_field,
|
||||
index_params=self.index_params,
|
||||
using=self.alias,
|
||||
)
|
||||
logger.debug(
|
||||
"Successfully created an index on collection: %s",
|
||||
self.collection_name,
|
||||
)
|
||||
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to create an index on collection: %s", self.collection_name
|
||||
)
|
||||
raise e
|
||||
|
||||
def _create_search_params(self) -> None:
|
||||
"""Generate search params based on the current index type"""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection) and self.search_params is None:
|
||||
index = self._get_index()
|
||||
if index is not None:
|
||||
index_type: str = index["index_param"]["index_type"]
|
||||
metric_type: str = index["index_param"]["metric_type"]
|
||||
self.search_params = self.default_search_params[index_type]
|
||||
self.search_params["metric_type"] = metric_type
|
||||
|
||||
def _load(self) -> None:
|
||||
"""Load the collection if available."""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection) and self._get_index() is not None:
|
||||
self.col.load()
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
batch_size: int = 1000,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Insert text data into Milvus.
|
||||
|
||||
Inserting data when the collection has not be made yet will result
|
||||
in creating a new Collection. The data of the first entity decides
|
||||
the schema of the new collection, the dim is extracted from the first
|
||||
embedding and the columns are decided by the first metadata dict.
|
||||
Metada keys will need to be present for all inserted values. At
|
||||
the moment there is no None equivalent in Milvus.
|
||||
|
||||
Args:
|
||||
texts (Iterable[str]): The texts to embed, it is assumed
|
||||
that they all fit in memory.
|
||||
metadatas (Optional[List[dict]]): Metadata dicts attached to each of
|
||||
the texts. Defaults to None.
|
||||
timeout (Optional[int]): Timeout for each batch insert. Defaults
|
||||
to None.
|
||||
batch_size (int, optional): Batch size to use for insertion.
|
||||
Defaults to 1000.
|
||||
|
||||
Raises:
|
||||
MilvusException: Failure to add texts
|
||||
|
||||
Returns:
|
||||
List[str]: The resulting keys for each inserted element.
|
||||
"""
|
||||
from pymilvus import Collection, MilvusException
|
||||
|
||||
texts = list(texts)
|
||||
|
||||
try:
|
||||
embeddings = self.embedding_func.embed_documents(texts)
|
||||
except NotImplementedError:
|
||||
embeddings = [self.embedding_func.embed_query(x) for x in texts]
|
||||
|
||||
if len(embeddings) == 0:
|
||||
logger.debug("Nothing to insert, skipping.")
|
||||
return []
|
||||
|
||||
# If the collection hasn't been initialized yet, perform all steps to do so
|
||||
if not isinstance(self.col, Collection):
|
||||
self._init(embeddings, metadatas)
|
||||
|
||||
# Dict to hold all insert columns
|
||||
insert_dict: dict[str, list] = {
|
||||
self._text_field: texts,
|
||||
self._vector_field: embeddings,
|
||||
}
|
||||
|
||||
# Collect the metadata into the insert dict.
|
||||
# if metadatas is not None:
|
||||
# for d in metadatas:
|
||||
# for key, value in d.items():
|
||||
# if key in self.fields:
|
||||
# insert_dict.setdefault(key, []).append(value)
|
||||
if metadatas is not None:
|
||||
for d in metadatas:
|
||||
insert_dict.setdefault(self._metadata_field, []).append(d)
|
||||
|
||||
# Total insert count
|
||||
vectors: list = insert_dict[self._vector_field]
|
||||
total_count = len(vectors)
|
||||
|
||||
pks: list[str] = []
|
||||
|
||||
assert isinstance(self.col, Collection)
|
||||
for i in range(0, total_count, batch_size):
|
||||
# Grab end index
|
||||
end = min(i + batch_size, total_count)
|
||||
# Convert dict to list of lists batch for insertion
|
||||
insert_list = [insert_dict[x][i:end] for x in self.fields]
|
||||
# Insert into the collection.
|
||||
try:
|
||||
res: Collection
|
||||
res = self.col.insert(insert_list, timeout=timeout, **kwargs)
|
||||
pks.extend(res.primary_keys)
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to insert batch starting at entity: %s/%s", i, total_count
|
||||
)
|
||||
raise e
|
||||
return pks
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Perform a similarity search against the query string.
|
||||
|
||||
Args:
|
||||
query (str): The text to search.
|
||||
k (int, optional): How many results to return. Defaults to 4.
|
||||
param (dict, optional): The search params for the index type.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
res = self.similarity_search_with_score(
|
||||
query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in res]
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Perform a similarity search against the query string.
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The embedding vector to search.
|
||||
k (int, optional): How many results to return. Defaults to 4.
|
||||
param (dict, optional): The search params for the index type.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
res = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in res]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Perform a search on a query string and return results with score.
|
||||
|
||||
For more information about the search parameters, take a look at the pymilvus
|
||||
documentation found here:
|
||||
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
||||
|
||||
Args:
|
||||
query (str): The text being searched.
|
||||
k (int, optional): The amount of results to return. Defaults to 4.
|
||||
param (dict): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[float], List[Tuple[Document, any, any]]:
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
# Embed the query text.
|
||||
embedding = self.embedding_func.embed_query(query)
|
||||
|
||||
res = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||
)
|
||||
return res
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and relevance scores in the range [0, 1].
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
|
||||
Args:
|
||||
query: input text
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include:
|
||||
score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs
|
||||
|
||||
Returns:
|
||||
List of Tuples of (doc, similarity_score)
|
||||
"""
|
||||
return self.similarity_search_with_score(query, k, **kwargs)
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Perform a search on a query string and return results with score.
|
||||
|
||||
For more information about the search parameters, take a look at the pymilvus
|
||||
documentation found here:
|
||||
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The embedding vector being searched.
|
||||
k (int, optional): The amount of results to return. Defaults to 4.
|
||||
param (dict): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: Result doc and score.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
if param is None:
|
||||
param = self.search_params
|
||||
|
||||
# Determine result metadata fields.
|
||||
output_fields = self.fields[:]
|
||||
output_fields.remove(self._vector_field)
|
||||
|
||||
# Perform the search.
|
||||
res = self.col.search(
|
||||
data=[embedding],
|
||||
anns_field=self._vector_field,
|
||||
param=param,
|
||||
limit=k,
|
||||
expr=expr,
|
||||
output_fields=output_fields,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
# Organize results.
|
||||
ret = []
|
||||
for result in res[0]:
|
||||
meta = {x: result.entity.get(x) for x in output_fields}
|
||||
doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata'))
|
||||
pair = (doc, result.score)
|
||||
ret.append(pair)
|
||||
|
||||
return ret
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Perform a search and return results that are reordered by MMR.
|
||||
|
||||
Args:
|
||||
query (str): The text being searched.
|
||||
k (int, optional): How many results to give. Defaults to 4.
|
||||
fetch_k (int, optional): Total results to select k from.
|
||||
Defaults to 20.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5
|
||||
param (dict, optional): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
embedding = self.embedding_func.embed_query(query)
|
||||
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding=embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
param=param,
|
||||
expr=expr,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Perform a search and return results that are reordered by MMR.
|
||||
|
||||
Args:
|
||||
embedding (str): The embedding vector being searched.
|
||||
k (int, optional): How many results to give. Defaults to 4.
|
||||
fetch_k (int, optional): Total results to select k from.
|
||||
Defaults to 20.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5
|
||||
param (dict, optional): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
if param is None:
|
||||
param = self.search_params
|
||||
|
||||
# Determine result metadata fields.
|
||||
output_fields = self.fields[:]
|
||||
output_fields.remove(self._vector_field)
|
||||
|
||||
# Perform the search.
|
||||
res = self.col.search(
|
||||
data=[embedding],
|
||||
anns_field=self._vector_field,
|
||||
param=param,
|
||||
limit=fetch_k,
|
||||
expr=expr,
|
||||
output_fields=output_fields,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
# Organize results.
|
||||
ids = []
|
||||
documents = []
|
||||
scores = []
|
||||
for result in res[0]:
|
||||
meta = {x: result.entity.get(x) for x in output_fields}
|
||||
doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
|
||||
documents.append(doc)
|
||||
scores.append(result.score)
|
||||
ids.append(result.id)
|
||||
|
||||
vectors = self.col.query(
|
||||
expr=f"{self._primary_field} in {ids}",
|
||||
output_fields=[self._primary_field, self._vector_field],
|
||||
timeout=timeout,
|
||||
)
|
||||
# Reorganize the results from query to match search order.
|
||||
vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
|
||||
|
||||
ordered_result_embeddings = [vectors[x] for x in ids]
|
||||
|
||||
# Get the new order of results.
|
||||
new_ordering = maximal_marginal_relevance(
|
||||
np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
|
||||
)
|
||||
|
||||
# Reorder the values and return.
|
||||
ret = []
|
||||
for x in new_ordering:
|
||||
# Function can return -1 index
|
||||
if x == -1:
|
||||
break
|
||||
else:
|
||||
ret.append(documents[x])
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
collection_name: str = "LangChainCollection",
|
||||
connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
|
||||
consistency_level: str = "Session",
|
||||
index_params: Optional[dict] = None,
|
||||
search_params: Optional[dict] = None,
|
||||
drop_old: bool = False,
|
||||
batch_size: int = 100,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Milvus:
|
||||
"""Create a Milvus collection, indexes it with HNSW, and insert data.
|
||||
|
||||
Args:
|
||||
texts (List[str]): Text data.
|
||||
embedding (Embeddings): Embedding function.
|
||||
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
|
||||
Defaults to None.
|
||||
collection_name (str, optional): Collection name to use. Defaults to
|
||||
"LangChainCollection".
|
||||
connection_args (dict[str, Any], optional): Connection args to use. Defaults
|
||||
to DEFAULT_MILVUS_CONNECTION.
|
||||
consistency_level (str, optional): Which consistency level to use. Defaults
|
||||
to "Session".
|
||||
index_params (Optional[dict], optional): Which index_params to use. Defaults
|
||||
to None.
|
||||
search_params (Optional[dict], optional): Which search params to use.
|
||||
Defaults to None.
|
||||
drop_old (Optional[bool], optional): Whether to drop the collection with
|
||||
that name if it exists. Defaults to False.
|
||||
batch_size:
|
||||
How many vectors upload per-request.
|
||||
Default: 100
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
|
||||
Returns:
|
||||
Milvus: Milvus Vector Store
|
||||
"""
|
||||
vector_db = cls(
|
||||
embedding_function=embedding,
|
||||
collection_name=collection_name,
|
||||
connection_args=connection_args,
|
||||
consistency_level=consistency_level,
|
||||
index_params=index_params,
|
||||
search_params=search_params,
|
||||
drop_old=drop_old,
|
||||
**kwargs,
|
||||
)
|
||||
vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
|
||||
return vector_db
|
||||
@@ -9,30 +9,44 @@ from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.milvus_vector_store import MilvusVectorStore
|
||||
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
|
||||
from models.dataset import Dataset
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
endpoint: str
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
secure: bool = False
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['endpoint']:
|
||||
raise ValueError("config MILVUS_ENDPOINT is required")
|
||||
if not values['host']:
|
||||
raise ValueError("config MILVUS_HOST is required")
|
||||
if not values['port']:
|
||||
raise ValueError("config MILVUS_PORT is required")
|
||||
if not values['user']:
|
||||
raise ValueError("config MILVUS_USER is required")
|
||||
if not values['password']:
|
||||
raise ValueError("config MILVUS_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
return {
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'user': self.user,
|
||||
'password': self.password,
|
||||
'secure': self.secure
|
||||
}
|
||||
|
||||
|
||||
class MilvusVectorIndex(BaseVectorIndex):
|
||||
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
|
||||
super().__init__(dataset, embeddings)
|
||||
self._client = self._init_client(config)
|
||||
self._client_config = config
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'milvus'
|
||||
@@ -49,7 +63,6 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
@@ -58,26 +71,29 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
index_params = {
|
||||
'metric_type': 'IP',
|
||||
'index_type': "HNSW",
|
||||
'params': {"M": 8, "efConstruction": 64}
|
||||
}
|
||||
self._vector_store = MilvusVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
connection_args=self._client_config.to_milvus_params(),
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
self._vector_store = MilvusVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=collection_name,
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
collection_name=collection_name,
|
||||
ids=uuids,
|
||||
content_payload_key='page_content'
|
||||
)
|
||||
|
||||
return self
|
||||
@@ -86,42 +102,53 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||
if self._is_origin():
|
||||
attributes = ['doc_id']
|
||||
|
||||
return WeaviateVectorStore(
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
text_key='text',
|
||||
embedding=self._embeddings,
|
||||
attributes=attributes,
|
||||
by_text=False
|
||||
return MilvusVectorStore(
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
embedding_function=self._embeddings,
|
||||
connection_args=self._client_config.to_milvus_params()
|
||||
)
|
||||
|
||||
def _get_vector_store_class(self) -> type:
|
||||
return MilvusVectorStore
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
ids = vector_store.get_ids_by_document_id(document_id)
|
||||
if ids:
|
||||
vector_store.del_texts({
|
||||
'filter': f'id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_ids(self, doc_ids: list[str]) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
ids = vector_store.get_ids_by_doc_ids(doc_ids)
|
||||
vector_store.del_texts({
|
||||
'filter': f' id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.del_texts({
|
||||
"operator": "Equal",
|
||||
"path": ["document_id"],
|
||||
"valueText": document_id
|
||||
})
|
||||
vector_store.delete()
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
return True
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
return False
|
||||
from qdrant_client.http import models
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self.dataset.id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
@@ -1390,70 +1390,12 @@ class Qdrant(VectorStore):
|
||||
path=path,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# Skip any validation in case of forced collection recreate.
|
||||
if force_recreate:
|
||||
raise ValueError
|
||||
|
||||
# Get the vector configuration of the existing collection and vector, if it
|
||||
# was specified. If the old configuration does not match the current one,
|
||||
# an exception is being thrown.
|
||||
collection_info = client.get_collection(collection_name=collection_name)
|
||||
current_vector_config = collection_info.config.params.vectors
|
||||
if isinstance(current_vector_config, dict) and vector_name is not None:
|
||||
if vector_name not in current_vector_config:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} does not "
|
||||
f"contain vector named {vector_name}. Did you mean one of the "
|
||||
f"existing vectors: {', '.join(current_vector_config.keys())}? "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
current_vector_config = current_vector_config.get(
|
||||
vector_name
|
||||
) # type: ignore[assignment]
|
||||
elif isinstance(current_vector_config, dict) and vector_name is None:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} uses named vectors. "
|
||||
f"If you want to reuse it, please set `vector_name` to any of the "
|
||||
f"existing named vectors: "
|
||||
f"{', '.join(current_vector_config.keys())}." # noqa
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
elif (
|
||||
not isinstance(current_vector_config, dict) and vector_name is not None
|
||||
):
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} doesn't use named "
|
||||
f"vectors. If you want to reuse it, please set `vector_name` to "
|
||||
f"`None`. If you want to recreate the collection, set "
|
||||
f"`force_recreate` parameter to `True`."
|
||||
)
|
||||
|
||||
# Check if the vector configuration has the same dimensionality.
|
||||
if current_vector_config.size != vector_size: # type: ignore[union-attr]
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for vectors with "
|
||||
f"{current_vector_config.size} " # type: ignore[union-attr]
|
||||
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
|
||||
current_distance_func = (
|
||||
current_vector_config.distance.name.upper() # type: ignore[union-attr]
|
||||
)
|
||||
if current_distance_func != distance_func:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for "
|
||||
f"{current_vector_config.distance} " # type: ignore[union-attr]
|
||||
f"similarity. Please set `distance_func` parameter to "
|
||||
f"`{distance_func}` if you want to reuse it. If you want to "
|
||||
f"recreate the collection, set `force_recreate` parameter to "
|
||||
f"`True`."
|
||||
)
|
||||
except (UnexpectedResponse, RpcError, ValueError):
|
||||
all_collection_name = []
|
||||
collections_response = client.get_collections()
|
||||
collection_list = collections_response.collections
|
||||
for collection in collection_list:
|
||||
all_collection_name.append(collection.name)
|
||||
if collection_name not in all_collection_name:
|
||||
vectors_config = rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[distance_func],
|
||||
@@ -1481,6 +1423,67 @@ class Qdrant(VectorStore):
|
||||
timeout=timeout, # type: ignore[arg-type]
|
||||
)
|
||||
is_new_collection = True
|
||||
if force_recreate:
|
||||
raise ValueError
|
||||
|
||||
# Get the vector configuration of the existing collection and vector, if it
|
||||
# was specified. If the old configuration does not match the current one,
|
||||
# an exception is being thrown.
|
||||
collection_info = client.get_collection(collection_name=collection_name)
|
||||
current_vector_config = collection_info.config.params.vectors
|
||||
if isinstance(current_vector_config, dict) and vector_name is not None:
|
||||
if vector_name not in current_vector_config:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} does not "
|
||||
f"contain vector named {vector_name}. Did you mean one of the "
|
||||
f"existing vectors: {', '.join(current_vector_config.keys())}? "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
current_vector_config = current_vector_config.get(
|
||||
vector_name
|
||||
) # type: ignore[assignment]
|
||||
elif isinstance(current_vector_config, dict) and vector_name is None:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} uses named vectors. "
|
||||
f"If you want to reuse it, please set `vector_name` to any of the "
|
||||
f"existing named vectors: "
|
||||
f"{', '.join(current_vector_config.keys())}." # noqa
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
elif (
|
||||
not isinstance(current_vector_config, dict) and vector_name is not None
|
||||
):
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} doesn't use named "
|
||||
f"vectors. If you want to reuse it, please set `vector_name` to "
|
||||
f"`None`. If you want to recreate the collection, set "
|
||||
f"`force_recreate` parameter to `True`."
|
||||
)
|
||||
|
||||
# Check if the vector configuration has the same dimensionality.
|
||||
if current_vector_config.size != vector_size: # type: ignore[union-attr]
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for vectors with "
|
||||
f"{current_vector_config.size} " # type: ignore[union-attr]
|
||||
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
|
||||
current_distance_func = (
|
||||
current_vector_config.distance.name.upper() # type: ignore[union-attr]
|
||||
)
|
||||
if current_distance_func != distance_func:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for "
|
||||
f"{current_vector_config.distance} " # type: ignore[union-attr]
|
||||
f"similarity. Please set `distance_func` parameter to "
|
||||
f"`{distance_func}` if you want to reuse it. If you want to "
|
||||
f"recreate the collection, set `force_recreate` parameter to "
|
||||
f"`True`."
|
||||
)
|
||||
qdrant = cls(
|
||||
client=client,
|
||||
collection_name=collection_name,
|
||||
|
||||
@@ -169,6 +169,19 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
],
|
||||
))
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self.dataset.id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
|
||||
@@ -47,6 +47,20 @@ class VectorIndex:
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
elif vector_type == "milvus":
|
||||
from core.index.vector_index.milvus_vector_index import MilvusVectorIndex, MilvusConfig
|
||||
|
||||
return MilvusVectorIndex(
|
||||
dataset=dataset,
|
||||
config=MilvusConfig(
|
||||
host=config.get('MILVUS_HOST'),
|
||||
port=config.get('MILVUS_PORT'),
|
||||
user=config.get('MILVUS_USER'),
|
||||
password=config.get('MILVUS_PASSWORD'),
|
||||
secure=config.get('MILVUS_SECURE'),
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from flask import current_app, Flask
|
||||
from flask_login import current_user
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
@@ -79,6 +80,8 @@ class IndexingRunner:
|
||||
dataset_document.error = str(e.description)
|
||||
dataset_document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
except ObjectDeletedError:
|
||||
logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
|
||||
except Exception as e:
|
||||
logging.exception("consume document failed")
|
||||
dataset_document.indexing_status = 'error'
|
||||
@@ -276,13 +279,14 @@ class IndexingRunner:
|
||||
)
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||
doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
@@ -372,13 +376,14 @@ class IndexingRunner:
|
||||
)
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||
doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
@@ -582,7 +587,6 @@ class IndexingRunner:
|
||||
|
||||
all_qa_documents.extend(format_documents)
|
||||
|
||||
|
||||
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
processing_rule: DatasetProcessRule) -> List[Document]:
|
||||
"""
|
||||
@@ -734,6 +738,9 @@ class IndexingRunner:
|
||||
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
|
||||
if count > 0:
|
||||
raise DocumentIsPausedException()
|
||||
document = DatasetDocument.query.filter_by(id=document_id).first()
|
||||
if not document:
|
||||
raise DocumentIsDeletedPausedException()
|
||||
|
||||
update_params = {
|
||||
DatasetDocument.indexing_status: after_indexing_status
|
||||
@@ -781,3 +788,7 @@ class IndexingRunner:
|
||||
|
||||
class DocumentIsPausedException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIsDeletedPausedException(Exception):
|
||||
pass
|
||||
|
||||
@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
|
||||
chat_messages: List[PromptMessage] = []
|
||||
for message in messages:
|
||||
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
|
||||
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
|
||||
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
|
||||
|
||||
if not chat_messages:
|
||||
|
||||
@@ -51,6 +51,9 @@ class ModelProviderFactory:
|
||||
elif provider_name == 'chatglm':
|
||||
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
|
||||
return ChatGLMProvider
|
||||
elif provider_name == 'baichuan':
|
||||
from core.model_providers.providers.baichuan_provider import BaichuanProvider
|
||||
return BaichuanProvider
|
||||
elif provider_name == 'azure_openai':
|
||||
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
return AzureOpenAIProvider
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
|
||||
|
||||
class HuggingfaceEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = HuggingfaceHubEmbeddings(
|
||||
model=name,
|
||||
**credentials
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Huggingface embedding: {str(ex)}")
|
||||
@@ -0,0 +1,22 @@
|
||||
from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
|
||||
|
||||
class OpenLLMEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = OpenLLMEmbeddings(
|
||||
server_url=credentials['server_url']
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"OpenLLM embedding: {str(ex)}")
|
||||
@@ -1,5 +1,4 @@
|
||||
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
|
||||
from replicate.exceptions import ModelError, ReplicateError
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
@@ -21,7 +20,4 @@ class XinferenceEmbedding(BaseEmbedding):
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, (ModelError, ReplicateError)):
|
||||
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import enum
|
||||
|
||||
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
||||
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -9,26 +9,31 @@ class LLMRunResult(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
source: list = None
|
||||
function_call: dict = None
|
||||
|
||||
|
||||
class MessageType(enum.Enum):
|
||||
HUMAN = 'human'
|
||||
USER = 'user'
|
||||
ASSISTANT = 'assistant'
|
||||
SYSTEM = 'system'
|
||||
|
||||
|
||||
class PromptMessage(BaseModel):
|
||||
type: MessageType = MessageType.HUMAN
|
||||
type: MessageType = MessageType.USER
|
||||
content: str = ''
|
||||
function_call: dict = None
|
||||
|
||||
|
||||
def to_lc_messages(messages: list[PromptMessage]):
|
||||
lc_messages = []
|
||||
for message in messages:
|
||||
if message.type == MessageType.HUMAN:
|
||||
if message.type == MessageType.USER:
|
||||
lc_messages.append(HumanMessage(content=message.content))
|
||||
elif message.type == MessageType.ASSISTANT:
|
||||
lc_messages.append(AIMessage(content=message.content))
|
||||
additional_kwargs = {}
|
||||
if message.function_call:
|
||||
additional_kwargs['function_call'] = message.function_call
|
||||
lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
|
||||
elif message.type == MessageType.SYSTEM:
|
||||
lc_messages.append(SystemMessage(content=message.content))
|
||||
|
||||
@@ -39,11 +44,21 @@ def to_prompt_messages(messages: list[BaseMessage]):
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, HumanMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
|
||||
elif isinstance(message, AIMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
|
||||
message_kwargs = {
|
||||
'content': message.content,
|
||||
'type': MessageType.ASSISTANT
|
||||
}
|
||||
|
||||
if 'function_call' in message.additional_kwargs:
|
||||
message_kwargs['function_call'] = message.additional_kwargs['function_call']
|
||||
|
||||
prompt_messages.append(PromptMessage(**message_kwargs))
|
||||
elif isinstance(message, SystemMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
|
||||
elif isinstance(message, FunctionMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
|
||||
return prompt_messages
|
||||
|
||||
|
||||
|
||||
@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
generate_kwargs = {
|
||||
'stop': stop,
|
||||
'callbacks': callbacks
|
||||
}
|
||||
|
||||
if isinstance(prompts, str):
|
||||
generate_kwargs['prompts'] = [prompts]
|
||||
else:
|
||||
generate_kwargs['messages'] = [prompts]
|
||||
|
||||
if 'functions' in kwargs:
|
||||
generate_kwargs['functions'] = kwargs['functions']
|
||||
|
||||
return self._client.generate(**generate_kwargs)
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
|
||||
67
api/core/model_providers/models/llm/baichuan_model.py
Normal file
67
api/core/model_providers/models/llm/baichuan_model.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
|
||||
|
||||
|
||||
class BaichuanModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return BaichuanChatLLM(
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
**provider_model_kwargs
|
||||
)
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
|
||||
if mode == 'completion':
|
||||
return 'baichuan_completion'
|
||||
else:
|
||||
return 'baichuan_chat'
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts), 0)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Baichuan: {str(ex)}")
|
||||
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Any, Union, Tuple
|
||||
import decimal
|
||||
@@ -12,14 +13,17 @@ from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage,
|
||||
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
|
||||
from core.helper import moderation
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
|
||||
to_lc_messages
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
import logging
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -154,8 +158,11 @@ class BaseLLM(BaseProviderModel):
|
||||
except Exception as ex:
|
||||
raise self.handle_exceptions(ex)
|
||||
|
||||
function_call = None
|
||||
if isinstance(result.generations[0][0], ChatGeneration):
|
||||
completion_content = result.generations[0][0].message.content
|
||||
if 'function_call' in result.generations[0][0].message.additional_kwargs:
|
||||
function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
|
||||
else:
|
||||
completion_content = result.generations[0][0].text
|
||||
|
||||
@@ -188,7 +195,8 @@ class BaseLLM(BaseProviderModel):
|
||||
return LLMRunResult(
|
||||
content=completion_content,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens
|
||||
completion_tokens=completion_tokens,
|
||||
function_call=function_call
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@@ -224,7 +232,7 @@ class BaseLLM(BaseProviderModel):
|
||||
:param message_type:
|
||||
:return:
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
@@ -242,7 +250,7 @@ class BaseLLM(BaseProviderModel):
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.0001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
@@ -257,7 +265,7 @@ class BaseLLM(BaseProviderModel):
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.000001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
|
||||
price_unit = self.price_config['unit']
|
||||
else:
|
||||
price_unit = self.price_config['unit']
|
||||
@@ -322,6 +330,85 @@ class BaseLLM(BaseProviderModel):
|
||||
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
|
||||
return [PromptMessage(content=prompt)], stops
|
||||
|
||||
def get_advanced_prompt(self, app_mode: str,
|
||||
app_model_config: str, inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
|
||||
|
||||
model_mode = app_model_config.model_dict['mode']
|
||||
conversation_histories_role = {}
|
||||
|
||||
raw_prompt_list = []
|
||||
prompt_messages = []
|
||||
|
||||
if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
|
||||
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
|
||||
raw_prompt_list = [{
|
||||
'role': MessageType.USER.value,
|
||||
'text': prompt_text
|
||||
}]
|
||||
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
|
||||
elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
|
||||
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
|
||||
elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
|
||||
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
|
||||
elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
|
||||
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
|
||||
raw_prompt_list = [{
|
||||
'role': MessageType.USER.value,
|
||||
'text': prompt_text
|
||||
}]
|
||||
else:
|
||||
raise Exception("app_mode or model_mode not support")
|
||||
|
||||
for prompt_item in raw_prompt_list:
|
||||
prompt = prompt_item['text']
|
||||
|
||||
# set prompt template variables
|
||||
prompt_template = PromptTemplateParser(template=prompt)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
|
||||
if '#context#' in prompt:
|
||||
if context:
|
||||
prompt_inputs['#context#'] = context
|
||||
else:
|
||||
prompt_inputs['#context#'] = ''
|
||||
|
||||
if '#query#' in prompt:
|
||||
if query:
|
||||
prompt_inputs['#query#'] = query
|
||||
else:
|
||||
prompt_inputs['#query#'] = ''
|
||||
|
||||
if '#histories#' in prompt:
|
||||
if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
|
||||
memory.human_prefix = conversation_histories_role['user_prefix']
|
||||
memory.ai_prefix = conversation_histories_role['assistant_prefix']
|
||||
histories = self._get_history_messages_from_memory(memory, 2000)
|
||||
prompt_inputs['#histories#'] = histories
|
||||
else:
|
||||
prompt_inputs['#histories#'] = ''
|
||||
|
||||
prompt = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
|
||||
prompt = re.sub(r'<\|.*?\|>', '', prompt)
|
||||
|
||||
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
|
||||
|
||||
if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
|
||||
memory.human_prefix = MessageType.USER.value
|
||||
memory.ai_prefix = MessageType.ASSISTANT.value
|
||||
histories = self._get_history_messages_list_from_memory(memory, 2000)
|
||||
prompt_messages.extend(histories)
|
||||
|
||||
if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
|
||||
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
|
||||
if mode == 'completion':
|
||||
return 'common_completion'
|
||||
@@ -334,17 +421,17 @@ class BaseLLM(BaseProviderModel):
|
||||
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
|
||||
context_prompt_content = ''
|
||||
if context and 'context_prompt' in prompt_rules:
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
|
||||
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
|
||||
context_prompt_content = prompt_template.format(
|
||||
context=context
|
||||
{'context': context}
|
||||
)
|
||||
|
||||
pre_prompt_content = ''
|
||||
if pre_prompt:
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
|
||||
prompt_template = PromptTemplateParser(template=pre_prompt)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
pre_prompt_content = prompt_template.format(
|
||||
**prompt_inputs
|
||||
prompt_inputs
|
||||
)
|
||||
|
||||
prompt = ''
|
||||
@@ -377,10 +464,8 @@ class BaseLLM(BaseProviderModel):
|
||||
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
|
||||
|
||||
histories = self._get_history_messages_from_memory(memory, rest_tokens)
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
|
||||
histories_prompt_content = prompt_template.format(
|
||||
histories=histories
|
||||
)
|
||||
prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
|
||||
histories_prompt_content = prompt_template.format({'histories': histories})
|
||||
|
||||
prompt = ''
|
||||
for order in prompt_rules['system_prompt_orders']:
|
||||
@@ -391,10 +476,8 @@ class BaseLLM(BaseProviderModel):
|
||||
elif order == 'histories_prompt':
|
||||
prompt += histories_prompt_content
|
||||
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
|
||||
query_prompt_content = prompt_template.format(
|
||||
query=query
|
||||
)
|
||||
prompt_template = PromptTemplateParser(template=query_prompt)
|
||||
query_prompt_content = prompt_template.format({'query': query})
|
||||
|
||||
prompt += query_prompt_content
|
||||
|
||||
@@ -425,6 +508,16 @@ class BaseLLM(BaseProviderModel):
|
||||
external_context = memory.load_memory_variables({})
|
||||
return external_context[memory_key]
|
||||
|
||||
def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
|
||||
max_token_limit: int) -> List[PromptMessage]:
|
||||
"""Get memory messages."""
|
||||
memory.max_token_limit = max_token_limit
|
||||
memory.return_messages = True
|
||||
memory_key = memory.memory_variables[0]
|
||||
external_context = memory.load_memory_variables({})
|
||||
memory.return_messages = False
|
||||
return to_prompt_messages(external_context[memory_key])
|
||||
|
||||
def _get_prompt_from_messages(self, messages: List[PromptMessage],
|
||||
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
|
||||
if not model_mode:
|
||||
@@ -439,16 +532,7 @@ class BaseLLM(BaseProviderModel):
|
||||
if len(messages) == 0:
|
||||
return []
|
||||
|
||||
chat_messages = []
|
||||
for message in messages:
|
||||
if message.type == MessageType.HUMAN:
|
||||
chat_messages.append(HumanMessage(content=message.content))
|
||||
elif message.type == MessageType.ASSISTANT:
|
||||
chat_messages.append(AIMessage(content=message.content))
|
||||
elif message.type == MessageType.SYSTEM:
|
||||
chat_messages.append(SystemMessage(content=message.content))
|
||||
|
||||
return chat_messages
|
||||
return to_lc_messages(messages)
|
||||
|
||||
def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
|
||||
"""
|
||||
|
||||
@@ -1,26 +1,23 @@
|
||||
import decimal
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import Minimax
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
|
||||
|
||||
|
||||
class MinimaxModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return Minimax(
|
||||
return MinimaxChatLLM(
|
||||
model=self.name,
|
||||
model_kwargs={
|
||||
'stream': False
|
||||
},
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
**provider_model_kwargs
|
||||
@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts), 0)
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
|
||||
return LLMBadRequestError(f"Minimax: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import List, Optional, Any
|
||||
import openai
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
from openai import api_requestor
|
||||
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
|
||||
@@ -105,7 +106,21 @@ class OpenAIModel(BaseLLM):
|
||||
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
|
||||
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
generate_kwargs = {
|
||||
'stop': stop,
|
||||
'callbacks': callbacks
|
||||
}
|
||||
|
||||
if isinstance(prompts, str):
|
||||
generate_kwargs['prompts'] = [prompts]
|
||||
else:
|
||||
generate_kwargs['messages'] = [prompts]
|
||||
|
||||
if 'functions' in kwargs:
|
||||
generate_kwargs['functions'] = kwargs['functions']
|
||||
|
||||
return self._client.generate(**generate_kwargs)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
|
||||
@@ -18,7 +18,6 @@ class TongyiModel(BaseLLM):
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
del provider_model_kwargs['max_tokens']
|
||||
return EnhanceTongyi(
|
||||
model_name=self.name,
|
||||
max_retries=1,
|
||||
@@ -58,7 +57,6 @@ class TongyiModel(BaseLLM):
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
del provider_model_kwargs['max_tokens']
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
|
||||
@@ -18,6 +18,7 @@ class WenxinModel(BaseLLM):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
# TODO load price_config from configs(db)
|
||||
return Wenxin(
|
||||
model=self.name,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
|
||||
@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
|
||||
from core.model_providers.models.entity.provider import ModelFeature
|
||||
from core.model_providers.models.llm.anthropic_model import AnthropicModel
|
||||
from core.model_providers.models.llm.base import ModelType
|
||||
@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'claude-instant-1',
|
||||
'name': 'claude-instant-1',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
},
|
||||
{
|
||||
'id': 'claude-2',
|
||||
'name': 'claude-2',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.CHAT.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
@@ -12,7 +12,7 @@ from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
|
||||
AZURE_OPENAI_API_VERSION
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
|
||||
from core.model_providers.models.entity.provider import ModelFeature
|
||||
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
}
|
||||
|
||||
credentials = json.loads(provider_model.encrypted_config)
|
||||
|
||||
if provider_model.model_type == ModelType.TEXT_GENERATION.value:
|
||||
model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
|
||||
|
||||
if credentials['base_model_name'] in [
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
|
||||
return model_list
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
if model_name == 'text-davinci-003':
|
||||
return ModelMode.COMPLETION.value
|
||||
else:
|
||||
return ModelMode.CHAT.value
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
models = [
|
||||
{
|
||||
'id': 'gpt-3.5-turbo',
|
||||
'name': 'gpt-3.5-turbo',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'gpt-3.5-turbo-16k',
|
||||
'name': 'gpt-3.5-turbo-16k',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'gpt-4',
|
||||
'name': 'gpt-4',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'gpt-4-32k',
|
||||
'name': 'gpt-4-32k',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'text-davinci-003',
|
||||
'name': 'text-davinci-003',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
171
api/core/model_providers/providers/baichuan_provider.py
Normal file
171
api/core/model_providers/providers/baichuan_provider.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.baichuan_model import BaichuanModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class BaichuanProvider(BaseModelProvider):
|
||||
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'baichuan'
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.CHAT.value
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
return [
|
||||
{
|
||||
'id': 'baichuan2-53b',
|
||||
'name': 'Baichuan2-53B',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = BaichuanModel
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](enabled=False),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
"""
|
||||
if 'api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Baichuan api_key must be provided.')
|
||||
|
||||
if 'secret_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Baichuan secret_key must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'api_key': credentials['api_key'],
|
||||
'secret_key': credentials['secret_key'],
|
||||
}
|
||||
|
||||
llm = BaichuanChatLLM(
|
||||
temperature=0,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm([HumanMessage(content='ping')])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
|
||||
credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
|
||||
return credentials
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
if self.provider.provider_type == ProviderType.CUSTOM.value:
|
||||
try:
|
||||
credentials = json.loads(self.provider.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
credentials = {
|
||||
'api_key': None,
|
||||
'secret_key': None,
|
||||
}
|
||||
|
||||
if credentials['api_key']:
|
||||
credentials['api_key'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['api_key']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
|
||||
|
||||
if credentials['secret_key']:
|
||||
credentials['secret_key'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['secret_key']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])
|
||||
|
||||
return credentials
|
||||
else:
|
||||
return {}
|
||||
|
||||
def should_deduct_quota(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
return self.get_provider_credentials(obfuscated)
|
||||
@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
|
||||
ProviderModel.is_valid == True
|
||||
).order_by(ProviderModel.created_at.asc()).all()
|
||||
|
||||
return [{
|
||||
'id': provider_model.model_name,
|
||||
'name': provider_model.model_name
|
||||
} for provider_model in provider_models]
|
||||
provider_model_list = []
|
||||
for provider_model in provider_models:
|
||||
provider_model_dict = {
|
||||
'id': provider_model.model_name,
|
||||
'name': provider_model.model_name
|
||||
}
|
||||
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
|
||||
|
||||
provider_model_list.append(provider_model_dict)
|
||||
|
||||
return provider_model_list
|
||||
|
||||
@abstractmethod
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
"""
|
||||
get text generation model mode.
|
||||
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_model_class(self, model_type: ModelType) -> Type:
|
||||
"""
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from models.provider import ProviderType
|
||||
@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'chatglm2-6b',
|
||||
'name': 'ChatGLM2-6B',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
},
|
||||
{
|
||||
'id': 'chatglm-6b',
|
||||
'name': 'ChatGLM-6B',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.COMPLETION.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
import json
|
||||
from typing import Type
|
||||
import requests
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
||||
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
|
||||
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
|
||||
from models.provider import ProviderType
|
||||
|
||||
HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'
|
||||
|
||||
|
||||
class HuggingfaceHubProvider(BaseModelProvider):
|
||||
@property
|
||||
@@ -24,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.COMPLETION.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
@@ -33,6 +41,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = HuggingfaceHubModel
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
model_class = HuggingfaceEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -63,7 +73,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
if model_type != ModelType.TEXT_GENERATION:
|
||||
if model_type not in [ModelType.TEXT_GENERATION, ModelType.EMBEDDINGS]:
|
||||
raise NotImplementedError
|
||||
|
||||
if 'huggingfacehub_api_type' not in credentials \
|
||||
@@ -88,19 +98,15 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
if 'task_type' not in credentials:
|
||||
raise CredentialsValidateFailedError('Task Type must be provided.')
|
||||
|
||||
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
|
||||
if credentials['task_type'] not in ("text2text-generation", "text-generation", 'feature-extraction'):
|
||||
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
|
||||
'text-generation, summarization.')
|
||||
'text-generation, feature-extraction.')
|
||||
|
||||
try:
|
||||
llm = HuggingFaceEndpointLLM(
|
||||
endpoint_url=credentials['huggingfacehub_endpoint_url'],
|
||||
task=credentials['task_type'],
|
||||
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
|
||||
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
if credentials['task_type'] == 'feature-extraction':
|
||||
cls.check_embedding_valid(credentials, model_name)
|
||||
else:
|
||||
cls.check_llm_valid(credentials)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
|
||||
else:
|
||||
@@ -112,13 +118,64 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
|
||||
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
|
||||
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "feature-extraction")
|
||||
if model_info.pipeline_tag not in VALID_TASKS:
|
||||
raise ValueError(f"Model {model_name} is not a valid task, "
|
||||
f"must be one of {VALID_TASKS}.")
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
|
||||
|
||||
@classmethod
|
||||
def check_llm_valid(cls, credentials: dict):
|
||||
llm = HuggingFaceEndpointLLM(
|
||||
endpoint_url=credentials['huggingfacehub_endpoint_url'],
|
||||
task=credentials['task_type'],
|
||||
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
|
||||
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
|
||||
@classmethod
|
||||
def check_embedding_valid(cls, credentials: dict, model_name: str):
|
||||
|
||||
cls.check_endpoint_url_model_repository_name(credentials, model_name)
|
||||
|
||||
embedding_model = HuggingfaceHubEmbeddings(
|
||||
model=model_name,
|
||||
**credentials
|
||||
)
|
||||
|
||||
embedding_model.embed_query("ping")
|
||||
|
||||
@classmethod
|
||||
def check_endpoint_url_model_repository_name(cls, credentials: dict, model_name: str):
|
||||
try:
|
||||
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
|
||||
headers = {
|
||||
'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
response =requests.get(url=url, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ValueError('User Name or Organization Name is invalid.')
|
||||
|
||||
model_repository_name = ''
|
||||
|
||||
for item in response.json().get("items", []):
|
||||
if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
|
||||
model_repository_name = item.get("model", {}).get("repository")
|
||||
break
|
||||
|
||||
if model_repository_name != model_name:
|
||||
raise ValueError(f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
|
||||
from core.model_providers.models.llm.localai_model import LocalAIModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
|
||||
if credentials['completion_type'] == 'chat_completion':
|
||||
return ModelMode.CHAT.value
|
||||
else:
|
||||
return ModelMode.COMPLETION.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
@@ -2,14 +2,15 @@ import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from langchain.llms import Minimax
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.minimax_model import MinimaxModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
|
||||
from models.provider import ProviderType, ProviderQuotaType
|
||||
|
||||
|
||||
@@ -28,10 +29,12 @@ class MinimaxProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'abab5.5-chat',
|
||||
'name': 'abab5.5-chat',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
},
|
||||
{
|
||||
'id': 'abab5-chat',
|
||||
'name': 'abab5-chat',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
}
|
||||
]
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
@@ -44,6 +47,9 @@ class MinimaxProvider(BaseModelProvider):
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.COMPLETION.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
@@ -98,14 +104,14 @@ class MinimaxProvider(BaseModelProvider):
|
||||
'minimax_api_key': credentials['minimax_api_key'],
|
||||
}
|
||||
|
||||
llm = Minimax(
|
||||
llm = MinimaxChatLLM(
|
||||
model='abab5.5-chat',
|
||||
max_tokens=10,
|
||||
temperature=0.01,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
llm([HumanMessage(content='ping')])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user