Mirror of https://github.com/langgenius/dify.git (synced 2026-01-07 06:48:28 +00:00)
Compare commits
76 Commits
chore/upgr… ... feat/suppo…
| SHA1 |
|---|
| 9977b8683e |
| d4be5ef9de |
| 1374be5a31 |
| b2bbc28580 |
| 59b3e672aa |
| a2f8bce8f5 |
| a2b9adb3a2 |
| 28067640b5 |
| da67916843 |
| e54ce479ad |
| 6024d8a42d |
| f565f08aa0 |
| fd4afe09f8 |
| dd0904f95c |
| 4c3076f2a4 |
| 1e73f63ff8 |
| d167d5b1be |
| 71fa14f791 |
| 8dd1873e76 |
| f91f5c7401 |
| c62b7cc679 |
| 3ee213ddca |
| 8429877b02 |
| 05a0faff6a |
| e09f6e4987 |
| e23f4b0265 |
| f582d4a13e |
| 2f41bd495d |
| 162a8c4393 |
| 3d1ce4c53f |
| 6db3ae9b8e |
| 6d0cb9dc33 |
| 46e95e8309 |
| a7b9375877 |
| 0c6a8a130e |
| 9903f1e703 |
| 6fad719e42 |
| 9aaee8ee47 |
| 166221d784 |
| 925d69a2ee |
| 5ff08e241a |
| 3defd24087 |
| 9d86147d20 |
| 80801ac4ab |
| 210926cd91 |
| 677a69deed |
| 8dfdee21ce |
| 6ea77ab4cd |
| e3c996688d |
| bc3a570dda |
| 0800021a2d |
| 435eddd867 |
| 6e0fb055d1 |
| 1e9ac7ffeb |
| b4873ecb43 |
| 1859d57784 |
| 69d58fbb50 |
| cb34991663 |
| c700364e1c |
| 9a6b1dc3a1 |
| 54b5b80a07 |
| 831459b895 |
| 4e101604c3 |
| a6455269f0 |
| cd257b91c5 |
| d8f57bf899 |
| 989fb11fd7 |
| 140965b738 |
| 14ee51aead |
| 2e97ba5700 |
| f549d53b68 |
| a085ad4719 |
| f230a9232e |
| e84bf35e2a |
| 20f090537f |
| dbe7a7c4fd |
.github/actions/setup-poetry/action.yml (2 changes)
@@ -8,7 +8,7 @@ inputs:
   poetry-version:
     description: Poetry version to set up
     required: true
-    default: '1.8.4'
+    default: '2.0.1'
   poetry-lockfile:
     description: Path to the Poetry lockfile to restore cache from
     required: true
.github/workflows/api-tests.yml (16 changes)
@@ -42,25 +42,23 @@ jobs:
        run: poetry install -C api --with dev

      - name: Check dependencies in pyproject.toml
-        run: poetry run -C api bash dev/pytest/pytest_artifacts.sh
+        run: poetry run -P api bash dev/pytest/pytest_artifacts.sh

      - name: Run Unit tests
-        run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh
+        run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh

      - name: Run ModelRuntime
-        run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
+        run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh

      - name: Run dify config tests
-        run: poetry run -C api python dev/pytest/pytest_config_tests.py
+        run: poetry run -P api python dev/pytest/pytest_config_tests.py

      - name: Run Tool
-        run: poetry run -C api bash dev/pytest/pytest_tools.sh
+        run: poetry run -P api bash dev/pytest/pytest_tools.sh

      - name: Run mypy
        run: |
-          pushd api
-          poetry run python -m mypy --install-types --non-interactive .
-          popd
+          poetry run -C api python -m mypy --install-types --non-interactive .

      - name: Set up dotenvs
        run: |
@@ -80,4 +78,4 @@ jobs:
          ssrf_proxy

      - name: Run Workflow
-        run: poetry run -C api bash dev/pytest/pytest_workflow.sh
+        run: poetry run -P api bash dev/pytest/pytest_workflow.sh
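The recurring `-C api` → `-P api` edits above track the Poetry 2.0 upgrade pinned earlier in this diff. As far as I can tell from the Poetry 2.0 docs (this reading is mine, not stated in the diff), `-C/--directory` now resolves all command-line arguments relative to the given directory, while the new `-P/--project` only sets the project root and keeps argument resolution at the current working directory. A minimal sketch, assuming the test scripts live at the repository root as the workflow paths suggest:

```bash
# -P/--project: api/ is the Poetry project, but dev/pytest/... is still
# resolved from the repo root, where the scripts live.
poetry run -P api bash dev/pytest/pytest_unit_tests.sh

# -C/--directory: arguments resolve relative to api/, so the same invocation
# would look for api/dev/pytest/pytest_unit_tests.sh instead.
poetry run -C api bash dev/pytest/pytest_unit_tests.sh
```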
.github/workflows/style.yml (6 changes)
@@ -38,12 +38,12 @@ jobs:
        if: steps.changed-files.outputs.any_changed == 'true'
        run: |
          poetry run -C api ruff --version
-          poetry run -C api ruff check ./api
-          poetry run -C api ruff format --check ./api
+          poetry run -C api ruff check ./
+          poetry run -C api ruff format --check ./

      - name: Dotenv check
        if: steps.changed-files.outputs.any_changed == 'true'
-        run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
+        run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example

      - name: Lint hints
        if: failure()
.github/workflows/vdb-tests.yml (2 changes)
@@ -70,4 +70,4 @@ jobs:
          tidb

      - name: Test Vector Stores
-        run: poetry run -C api bash dev/pytest/pytest_vdb.sh
+        run: poetry run -P api bash dev/pytest/pytest_vdb.sh
@@ -53,10 +53,12 @@ ignore = [
     "FURB152", # math-constant
     "UP007", # non-pep604-annotation
     "UP032", # f-string
+    "UP045", # non-pep604-annotation-optional
     "B005", # strip-with-multi-characters
     "B006", # mutable-argument-default
     "B007", # unused-loop-control-variable
     "B026", # star-arg-unpacking-after-keyword-arg
+    "B903", # class-as-data-structure
     "B904", # raise-without-from-inside-except
     "B905", # zip-without-explicit-strict
     "N806", # non-lowercase-variable-in-function
@@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base

 WORKDIR /app/api

 # Install Poetry
-ENV POETRY_VERSION=1.8.4
+ENV POETRY_VERSION=2.0.1

 # if you located in China, you can use aliyun mirror to speed up
 # RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
@@ -79,5 +79,5 @@
 2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`

    ```bash
-   poetry run -C api bash dev/pytest/pytest_all_tests.sh
+   poetry run -P api bash dev/pytest/pytest_all_tests.sh
    ```
@@ -146,7 +146,7 @@ class EndpointConfig(BaseSettings):
     )

     CONSOLE_WEB_URL: str = Field(
-        description="Base URL for the console web interface," "used for frontend references and CORS configuration",
+        description="Base URL for the console web interface,used for frontend references and CORS configuration",
         default="",
     )
@@ -181,7 +181,7 @@ class HostedFetchAppTemplateConfig(BaseSettings):
     """

     HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
-        description="Mode for fetching app templates: remote, db, or builtin" " default to remote,",
+        description="Mode for fetching app templates: remote, db, or builtin default to remote,",
         default="remote",
     )
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.15.0",
+        default="0.15.2",
     )

     COMMIT_SHA: str = Field(
@@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):

         app = App.query.filter(App.id == args["app_id"]).first()
         if not app:
-            raise NotFound(f'App \'{args["app_id"]}\' is not found')
+            raise NotFound(f"App '{args['app_id']}' is not found")

         site = app.site
         if not site:
@@ -22,7 +22,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
-from models.model import AppMode
+from models import App, AppMode
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -79,7 +79,7 @@ class ChatMessageTextApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model
-    def post(self, app_model):
+    def post(self, app_model: App):
         from werkzeug.exceptions import InternalServerError

         try:
@@ -98,9 +98,13 @@ class ChatMessageTextApi(Resource):
                 and app_model.workflow.features_dict
             ):
                 text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+                if text_to_speech is None:
+                    raise ValueError("TTS is not enabled")
                 voice = args.get("voice") or text_to_speech.get("voice")
             else:
                 try:
+                    if app_model.app_model_config is None:
+                        raise ValueError("AppModelConfig not found")
                     voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
                 except Exception:
                     voice = None
@@ -52,12 +52,12 @@ class DatasetListApi(Resource):
         # provider = request.args.get("provider", default="vendor")
         search = request.args.get("keyword", default=None, type=str)
         tag_ids = request.args.getlist("tag_ids")
-
+        include_all = request.args.get("include_all", default="false").lower() == "true"
         if ids:
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
             datasets, total = DatasetService.get_datasets(
-                page, limit, current_user.current_tenant_id, current_user, search, tag_ids
+                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
             )

         # check embedding setting
@@ -457,7 +457,7 @@ class DatasetIndexingEstimateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -619,8 +619,7 @@ class DatasetRetrievalSettingApi(Resource):
         vector_type = dify_config.VECTOR_STORE
         match vector_type:
             case (
-                VectorType.MILVUS
-                | VectorType.RELYT
+                VectorType.RELYT
                 | VectorType.PGVECTOR
                 | VectorType.TIDB_VECTOR
                 | VectorType.CHROMA
@@ -645,6 +644,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.TIDB_ON_QDRANT
                 | VectorType.LINDORM
                 | VectorType.COUCHBASE
+                | VectorType.MILVUS
             ):
                 return {
                     "retrieval_method": [
@@ -350,8 +350,7 @@ class DatasetInitApi(Resource):
             )
         except InvokeAuthorizationError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -526,8 +525,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
             return response.model_dump(), 200
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -168,8 +168,7 @@ class DatasetDocumentSegmentApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -217,8 +216,7 @@ class DatasetDocumentSegmentAddApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -267,8 +265,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -368,9 +365,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
                     result = []
                     for index, row in df.iterrows():
                         if document.doc_form == "qa_model":
-                            data = {"content": row[0], "answer": row[1]}
+                            data = {"content": row.iloc[0], "answer": row.iloc[1]}
                         else:
-                            data = {"content": row[0]}
+                            data = {"content": row.iloc[0]}
                         result.append(data)
                     if len(result) == 0:
                         raise ValueError("The CSV file is empty.")
@@ -437,8 +434,7 @@ class ChildChunkAddApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
            )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
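The `row[0]` → `row.iloc[0]` change in the batch-import hunk above matters because `row` is a pandas Series indexed by the CSV header: `row[0]` is label-based lookup (pandas has deprecated the old positional fallback, so it warns or fails depending on the version), while `row.iloc[0]` is unambiguous positional access. A small illustration with made-up column names:

```python
import pandas as pd

df = pd.DataFrame([["What is Dify?", "An LLM app platform"]], columns=["question", "answer"])
for _, row in df.iterrows():
    data = {"content": row.iloc[0], "answer": row.iloc[1]}  # positional, header-agnostic
    # row[0] would look for a column *labeled* 0 here and warn or raise
print(data)
```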
@@ -32,7 +32,7 @@ class ConversationListApi(InstalledAppResource):

         pinned = None
         if "pinned" in args and args["pinned"] is not None:
-            pinned = True if args["pinned"] == "true" else False
+            pinned = args["pinned"] == "true"

         try:
             with Session(db.engine) as session:
@@ -7,4 +7,4 @@ api = ExternalApi(bp)

 from . import index
 from .app import app, audio, completion, conversation, file, message, workflow
-from .dataset import dataset, document, hit_testing, segment
+from .dataset import dataset, document, hit_testing, segment, upload_file
@@ -31,8 +31,11 @@ class DatasetListApi(DatasetApiResource):
         # provider = request.args.get("provider", default="vendor")
         search = request.args.get("keyword", default=None, type=str)
         tag_ids = request.args.getlist("tag_ids")
+        include_all = request.args.get("include_all", default="false").lower() == "true"

-        datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
+        datasets, total = DatasetService.get_datasets(
+            page, limit, tenant_id, current_user, search, tag_ids, include_all
+        )
         # check embedding setting
         provider_manager = ProviderManager()
         configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
@@ -53,8 +53,7 @@ class SegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -95,8 +94,7 @@ class SegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -175,8 +173,7 @@ class DatasetSegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
api/controllers/service_api/dataset/upload_file.py (new file, 54 lines)
@@ -0,0 +1,54 @@
from werkzeug.exceptions import NotFound

from controllers.service_api import api
from controllers.service_api.wraps import (
    DatasetApiResource,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService


class UploadFileApi(DatasetApiResource):
    def get(self, tenant_id, dataset_id, document_id):
        """Get upload file."""
        # check dataset
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset.id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check upload file
        if document.data_source_type != "upload_file":
            raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
        data_source_info = document.data_source_info_dict
        if data_source_info and "upload_file_id" in data_source_info:
            file_id = data_source_info["upload_file_id"]
            upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
            if not upload_file:
                raise NotFound("UploadFile not found.")
        else:
            raise ValueError("Upload file id not found in document data source info.")

        url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
        return {
            "id": upload_file.id,
            "name": upload_file.name,
            "size": upload_file.size,
            "extension": upload_file.extension,
            "url": url,
            "download_url": f"{url}&as_attachment=true",
            "mime_type": upload_file.mime_type,
            "created_by": upload_file.created_by,
            "created_at": upload_file.created_at.timestamp(),
        }, 200


api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
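A hypothetical call against the new route (the base URL, IDs, and API-key variable are placeholders; the response keys mirror the handler above):

```bash
curl -s "https://api.dify.example/v1/datasets/<dataset_id>/documents/<document_id>/upload-file" \
  -H "Authorization: Bearer ${DATASET_API_KEY}"
# 200 response: {"id": ..., "name": ..., "size": ..., "extension": ...,
#   "url": "<signed url>", "download_url": "<signed url>&as_attachment=true",
#   "mime_type": ..., "created_by": ..., "created_at": <unix timestamp>}
```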
@@ -195,7 +195,11 @@ def validate_and_get_api_token(scope: str | None = None):
     with Session(db.engine, expire_on_commit=False) as session:
         update_stmt = (
             update(ApiToken)
-            .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
+            .where(
+                ApiToken.token == auth_token,
+                (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
+                ApiToken.type == scope,
+            )
             .values(last_used_at=current_time)
             .returning(ApiToken)
         )
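The extra `ApiToken.last_used_at.is_(None)` predicate fixes a three-valued-logic bug: in SQL, `NULL < cutoff_time` evaluates to NULL rather than true, so tokens that had never been used were silently excluded from the refresh. A self-contained sketch demonstrating the difference (the `Token` model is a stand-in with only the fields the predicate touches):

```python
from datetime import datetime, timedelta

from sqlalchemy import Column, DateTime, String, create_engine, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Token(Base):  # stand-in for ApiToken
    __tablename__ = "tokens"
    token = Column(String, primary_key=True)
    type = Column(String)
    last_used_at = Column(DateTime, nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
now = datetime.now()
cutoff = now - timedelta(minutes=10)

with Session(engine) as session:
    session.add(Token(token="never-used", type="dataset", last_used_at=None))
    session.commit()
    # Old predicate: NULL < cutoff evaluates to NULL in SQL, so 0 rows match.
    old = session.execute(update(Token).where(Token.last_used_at < cutoff).values(last_used_at=now))
    print(old.rowcount)  # 0
    # NULL-safe predicate from the hunk: the never-used token is refreshed too.
    new = session.execute(
        update(Token)
        .where(Token.last_used_at.is_(None) | (Token.last_used_at < cutoff))
        .values(last_used_at=now)
    )
    print(new.rowcount)  # 1
```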
@@ -236,7 +240,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
             type="service_api",
-            is_anonymous=True if user_id == "DEFAULT-USER" else False,
+            is_anonymous=user_id == "DEFAULT-USER",
             session_id=user_id,
         )
         db.session.add(end_user)
@@ -39,7 +39,7 @@ class ConversationListApi(WebApiResource):

         pinned = None
         if "pinned" in args and args["pinned"] is not None:
-            pinned = True if args["pinned"] == "true" else False
+            pinned = args["pinned"] == "true"

         try:
             with Session(db.engine) as session:
@@ -172,7 +172,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

             self.save_agent_thought(
                 agent_thought=agent_thought,
-                tool_name=scratchpad.action.action_name if scratchpad.action else "",
+                tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
                 tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
                 thought=scratchpad.thought or "",
@@ -167,8 +167,7 @@ class AppQueueManager:
         else:
             if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
                 raise TypeError(
-                    "Critical Error: Passing SQLAlchemy Model instances "
-                    "that cause thread safety issues is not allowed."
+                    "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
                 )
@@ -89,6 +89,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             Conversation.id == conversation_id,
             Conversation.app_id == app_model.id,
             Conversation.status == "normal",
+            Conversation.is_deleted.is_(False),
         ]

         if isinstance(user, Account):
@@ -145,7 +145,7 @@ class MessageCycleManage:

             # get extension
             if "." in message_file.url:
-                extension = f'.{message_file.url.split(".")[-1]}'
+                extension = f".{message_file.url.split('.')[-1]}"
                 if len(extension) > 10:
                     extension = ".bin"
             else:
@@ -62,8 +62,9 @@ class ApiExternalDataTool(ExternalDataTool):

         if not api_based_extension:
             raise ValueError(
-                "[External data tool] API query failed, variable: {}, "
-                "error: api_based_extension_id is invalid".format(self.variable)
+                "[External data tool] API query failed, variable: {}, error: api_based_extension_id is invalid".format(
+                    self.variable
+                )
             )

         # decrypt api_key
@@ -90,7 +90,7 @@ class File(BaseModel):
     def markdown(self) -> str:
         url = self.generate_url()
         if self.type == FileType.IMAGE:
-            text = f'![{self.filename or ""}]({url})'
+            text = f"![{self.filename or ''}]({url})"
         else:
             text = f"[{self.filename or url}]({url})"
@@ -530,7 +530,6 @@ class IndexingRunner:
         # chunk nodes by chunk size
         indexing_start_at = time.perf_counter()
         tokens = 0
-        chunk_size = 10
         if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
             # create keyword index
             create_keyword_thread = threading.Thread(
@@ -539,11 +538,22 @@ class IndexingRunner:
             )
             create_keyword_thread.start()

+        max_workers = 10
         if dataset.indexing_technique == "high_quality":
-            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                 futures = []
-                for i in range(0, len(documents), chunk_size):
-                    chunk_documents = documents[i : i + chunk_size]
+
+                # Distribute documents into multiple groups based on the hash values of page_content
+                # This is done to prevent multiple threads from processing the same document,
+                # Thereby avoiding potential database insertion deadlocks
+                document_groups: list[list[Document]] = [[] for _ in range(max_workers)]
+                for document in documents:
+                    hash = helper.generate_text_hash(document.page_content)
+                    group_index = int(hash, 16) % max_workers
+                    document_groups[group_index].append(document)
+                for chunk_documents in document_groups:
+                    if len(chunk_documents) == 0:
+                        continue
                     futures.append(
                         executor.submit(
                             self._process_chunk,
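The second hunk replaces fixed-size slicing with hash-partitioned groups so that duplicate chunks (same `page_content`, hence same hash) always land in the same worker and cannot race each other into a database insertion deadlock. A self-contained sketch of just the partitioning step, assuming `helper.generate_text_hash` produces a hex digest (the `sha256` call here is our stand-in):

```python
import hashlib


def partition_by_content(documents: list[str], max_workers: int = 10) -> list[list[str]]:
    """Send identical texts to the same group so no two workers insert the same row."""
    groups: list[list[str]] = [[] for _ in range(max_workers)]
    for text in documents:
        digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        groups[int(digest, 16) % max_workers].append(text)
    return groups


groups = partition_by_content(["same chunk", "same chunk", "other chunk"])
assert sum(len(g) for g in groups) == 3
# both copies of "same chunk" are guaranteed to share a group:
assert any(g.count("same chunk") == 2 for g in groups)
```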
@@ -131,7 +131,7 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
 SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "Please help me predict the three most likely questions that human would ask, "
     "and keeping each question under 20 characters.\n"
-    "MAKE SURE your output is the SAME language as the Assistant's latest response"
+    "MAKE SURE your output is the SAME language as the Assistant's latest response. "
     "The output must be an array in JSON format following the specified schema:\n"
     '["question1","question2","question3"]\n'
 )
@@ -1,6 +1,9 @@
+import logging
 from threading import Lock
 from typing import Any

+logger = logging.getLogger(__name__)
+
 _tokenizer: Any = None
 _lock = Lock()
@@ -43,5 +46,6 @@ class GPT2Tokenizer:
             base_path = abspath(__file__)
             gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
             _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
+            logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken")

         return _tokenizer
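The module-level `_tokenizer`/`_lock` pair visible in these hunks implies a lock-guarded lazy singleton: the tokenizer is loaded once, and the lock keeps concurrent first calls from racing. A generic sketch of that pattern (names are ours, and the loading step is faked):

```python
from threading import Lock
from typing import Any

_instance: Any = None
_lock = Lock()


def get_instance() -> Any:
    global _instance
    if _instance is None:  # fast path: no locking once initialized
        with _lock:
            if _instance is None:  # re-check inside the lock (double-checked locking)
                _instance = object()  # stand-in for the expensive from_pretrained() load
    return _instance
```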
@@ -108,7 +108,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

         if not ai_model_entity:
-            raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
+            raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")

         try:
             client = AzureOpenAI(**self._to_credential_kwargs(credentials))
@@ -130,7 +130,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
             raise CredentialsValidateFailedError("Base Model Name is required")

         if not self._get_ai_model_entity(credentials["base_model_name"], model):
-            raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
+            raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")

         try:
             credentials_kwargs = self._to_credential_kwargs(credentials)
@@ -44,6 +44,7 @@ provider_credential_schema:
       label:
         en_US: AWS Region
         zh_Hans: AWS 地区
+        ja_JP: AWS リージョン
       type: select
       default: us-east-1
       options:
@@ -51,62 +52,77 @@ provider_credential_schema:
           label:
             en_US: US East (N. Virginia)
             zh_Hans: 美国东部 (弗吉尼亚北部)
+            ja_JP: 米国 (バージニア北部)
         - value: us-east-2
           label:
             en_US: US East (Ohio)
-            zh_Hans: 美国东部 (弗吉尼亚北部)
+            zh_Hans: 美国东部 (俄亥俄)
+            ja_JP: 米国 (オハイオ)
         - value: us-west-2
           label:
             en_US: US West (Oregon)
             zh_Hans: 美国西部 (俄勒冈州)
+            ja_JP: 米国 (オレゴン)
         - value: ap-south-1
           label:
             en_US: Asia Pacific (Mumbai)
             zh_Hans: 亚太地区(孟买)
+            ja_JP: アジアパシフィック (ムンバイ)
         - value: ap-southeast-1
           label:
             en_US: Asia Pacific (Singapore)
             zh_Hans: 亚太地区 (新加坡)
+            ja_JP: アジアパシフィック (シンガポール)
         - value: ap-southeast-2
           label:
             en_US: Asia Pacific (Sydney)
             zh_Hans: 亚太地区 (悉尼)
+            ja_JP: アジアパシフィック (シドニー)
         - value: ap-northeast-1
           label:
             en_US: Asia Pacific (Tokyo)
             zh_Hans: 亚太地区 (东京)
+            ja_JP: アジアパシフィック (東京)
         - value: ap-northeast-2
           label:
             en_US: Asia Pacific (Seoul)
             zh_Hans: 亚太地区(首尔)
+            ja_JP: アジアパシフィック (ソウル)
         - value: ca-central-1
           label:
             en_US: Canada (Central)
             zh_Hans: 加拿大(中部)
+            ja_JP: カナダ (中部)
         - value: eu-central-1
           label:
             en_US: Europe (Frankfurt)
             zh_Hans: 欧洲 (法兰克福)
+            ja_JP: 欧州 (フランクフルト)
         - value: eu-west-1
           label:
             en_US: Europe (Ireland)
             zh_Hans: 欧洲(爱尔兰)
+            ja_JP: 欧州 (アイルランド)
         - value: eu-west-2
           label:
             en_US: Europe (London)
             zh_Hans: 欧洲西部 (伦敦)
+            ja_JP: 欧州 (ロンドン)
         - value: eu-west-3
           label:
             en_US: Europe (Paris)
             zh_Hans: 欧洲(巴黎)
+            ja_JP: 欧州 (パリ)
         - value: sa-east-1
           label:
             en_US: South America (São Paulo)
             zh_Hans: 南美洲(圣保罗)
+            ja_JP: 南米 (サンパウロ)
         - value: us-gov-west-1
           label:
             en_US: AWS GovCloud (US-West)
             zh_Hans: AWS GovCloud (US-West)
+            ja_JP: AWS GovCloud (米国西部)
     - variable: model_for_validation
       required: false
       label:
@@ -70,7 +70,7 @@ class BedrockRerankModel(RerankModel):
         rerankingConfiguration = {
             "type": "BEDROCK_RERANKING_MODEL",
             "bedrockRerankingConfiguration": {
-                "numberOfResults": top_n,
+                "numberOfResults": min(top_n, len(text_sources)),
                 "modelConfiguration": {
                     "modelArn": model_package_arn,
                 },
@@ -1,2 +1,3 @@
 - deepseek-chat
 - deepseek-coder
+- deepseek-reasoner
@@ -10,7 +10,7 @@ features:
   - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 128000
+  context_size: 64000
 parameter_rules:
   - name: temperature
     use_template: temperature
@@ -10,7 +10,7 @@ features:
   - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 128000
+  context_size: 64000
 parameter_rules:
   - name: temperature
     use_template: temperature
@@ -0,0 +1,21 @@
model: deepseek-reasoner
label:
  zh_Hans: deepseek-reasoner
  en_US: deepseek-reasoner
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 64000
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 8192
    default: 4096
pricing:
  input: "4"
  output: "16"
  unit: "0.000001"
  currency: RMB
@@ -1,10 +1,13 @@
+import json
 from collections.abc import Generator
 from typing import Optional, Union

+import requests
 from yarl import URL

-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
     PromptMessage,
     PromptMessageTool,
 )
@@ -24,9 +27,6 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         self._add_custom_parameters(credentials)
-        # {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}}
-        if "response_format" in model_parameters:
-            model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
         return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

     def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -39,3 +39,208 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
        credentials["mode"] = LLMMode.CHAT.value
        credentials["function_calling_type"] = "tool_call"
        credentials["stream_function_calling"] = "support"

    def _handle_generate_stream_response(
        self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
    ) -> Generator:
        """
        Handle llm stream response

        :param model: model name
        :param credentials: model credentials
        :param response: streamed response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator
        """
        full_assistant_content = ""
        chunk_index = 0
        is_reasoning_started = False  # Add flag to track reasoning state

        def create_final_llm_result_chunk(
            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
        ) -> LLMResultChunk:
            # calculate num tokens
            prompt_tokens = usage and usage.get("prompt_tokens")
            if prompt_tokens is None:
                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
            completion_tokens = usage and usage.get("completion_tokens")
            if completion_tokens is None:
                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)

            # transform usage
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

            return LLMResultChunk(
                id=id,
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
            )

        # delimiter for stream response, need unicode_escape
        import codecs

        delimiter = credentials.get("stream_mode_delimiter", "\n\n")
        delimiter = codecs.decode(delimiter, "unicode_escape")

        tools_calls: list[AssistantPromptMessage.ToolCall] = []

        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
            def get_tool_call(tool_call_id: str):
                if not tool_call_id:
                    return tools_calls[-1]

                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
                if tool_call is None:
                    tool_call = AssistantPromptMessage.ToolCall(
                        id=tool_call_id,
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
                    )
                    tools_calls.append(tool_call)

                return tool_call

            for new_tool_call in new_tool_calls:
                # get tool call
                tool_call = get_tool_call(new_tool_call.function.name)
                # update tool call
                if new_tool_call.id:
                    tool_call.id = new_tool_call.id
                if new_tool_call.type:
                    tool_call.type = new_tool_call.type
                if new_tool_call.function.name:
                    tool_call.function.name = new_tool_call.function.name
                if new_tool_call.function.arguments:
                    tool_call.function.arguments += new_tool_call.function.arguments

        finish_reason = None  # The default value of finish_reason is None
        message_id, usage = None, None
        for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
            chunk = chunk.strip()
            if chunk:
                # ignore sse comments
                if chunk.startswith(":"):
                    continue
                decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
                if decoded_chunk == "[DONE]":  # Some provider returns "data: [DONE]"
                    continue

                try:
                    chunk_json: dict = json.loads(decoded_chunk)
                # stream ended
                except json.JSONDecodeError as e:
                    yield create_final_llm_result_chunk(
                        id=message_id,
                        index=chunk_index + 1,
                        message=AssistantPromptMessage(content=""),
                        finish_reason="Non-JSON encountered.",
                        usage=usage,
                    )
                    break
                # handle the error here. for issue #11629
                if chunk_json.get("error") and chunk_json.get("choices") is None:
                    raise ValueError(chunk_json.get("error"))

                if chunk_json:
                    if u := chunk_json.get("usage"):
                        usage = u
                if not chunk_json or len(chunk_json["choices"]) == 0:
                    continue

                choice = chunk_json["choices"][0]
                finish_reason = chunk_json["choices"][0].get("finish_reason")
                message_id = chunk_json.get("id")
                chunk_index += 1

                if "delta" in choice:
                    delta = choice["delta"]
                    is_reasoning = delta.get("reasoning_content")
                    delta_content = delta.get("content") or delta.get("reasoning_content")

                    assistant_message_tool_calls = None

                    if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
                        assistant_message_tool_calls = delta.get("tool_calls", None)
                    elif (
                        "function_call" in delta
                        and credentials.get("function_calling_type", "no_call") == "function_call"
                    ):
                        assistant_message_tool_calls = [
                            {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
                        ]

                    # assistant_message_function_call = delta.delta.function_call

                    # extract tool calls from response
                    if assistant_message_tool_calls:
                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
                        increase_tool_call(tool_calls)

                    if delta_content is None or delta_content == "":
                        continue

                    # Add markdown quote markers for reasoning content
                    if is_reasoning:
                        if not is_reasoning_started:
                            delta_content = "> 💭 " + delta_content
                            is_reasoning_started = True
                        elif "\n\n" in delta_content:
                            delta_content = delta_content.replace("\n\n", "\n> ")
                        elif "\n" in delta_content:
                            delta_content = delta_content.replace("\n", "\n> ")
                    elif is_reasoning_started:
                        # If we were in reasoning mode but now getting regular content,
                        # add \n\n to close the reasoning block
                        delta_content = "\n\n" + delta_content
                        is_reasoning_started = False

                    # transform assistant message to prompt message
                    assistant_prompt_message = AssistantPromptMessage(
                        content=delta_content,
                    )

                    # reset tool calls
                    tool_calls = []
                    full_assistant_content += delta_content
                elif "text" in choice:
                    choice_text = choice.get("text", "")
                    if choice_text == "":
                        continue

                    # transform assistant message to prompt message
                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
                    full_assistant_content += choice_text
                else:
                    continue

                yield LLMResultChunk(
                    id=message_id,
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                    ),
                )

            chunk_index += 1

        if tools_calls:
            yield LLMResultChunk(
                id=message_id,
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=chunk_index,
                    message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
                ),
            )

        yield create_final_llm_result_chunk(
            id=message_id,
            index=chunk_index,
            message=AssistantPromptMessage(content=""),
            finish_reason=finish_reason,
            usage=usage,
        )
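The most distinctive piece of the new handler is how it folds DeepSeek's `reasoning_content` deltas into the visible stream: the first reasoning delta is prefixed with `> 💭` and later newlines inside reasoning become `\n> `, so the whole thought renders as a single markdown blockquote; the first regular delta then closes the quote with a blank line. A stripped-down re-implementation of just that state machine (function name and inputs are ours), mirroring the branches above:

```python
def fold_reasoning(deltas: list[tuple[bool, str]]) -> str:
    """deltas: (is_reasoning, text) pairs in stream order."""
    out, started = [], False
    for is_reasoning, text in deltas:
        if is_reasoning:
            if not started:
                text, started = "> 💭 " + text, True
            elif "\n\n" in text:
                text = text.replace("\n\n", "\n> ")
            elif "\n" in text:
                text = text.replace("\n", "\n> ")
        elif started:
            text, started = "\n\n" + text, False  # close the blockquote
        out.append(text)
    return "".join(out)


print(fold_reasoning([(True, "step one"), (True, "\nstep two"), (False, "Answer.")]))
# > 💭 step one
# > step two
#
# Answer.
```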
@@ -1,5 +1,6 @@
 - gemini-2.0-flash-exp
 - gemini-2.0-flash-thinking-exp-1219
+- gemini-2.0-flash-thinking-exp-01-21
 - gemini-1.5-pro
 - gemini-1.5-pro-latest
 - gemini-1.5-pro-001
@@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-01-21
label:
  en_US: Gemini 2.0 Flash Thinking Exp 01-21
model_type: llm
features:
  - agent-thought
  - vision
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
@@ -162,9 +162,9 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
     @staticmethod
     def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str):
         try:
-            url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
+            url = f"{HUGGINGFACE_ENDPOINT_API}{credentials['huggingface_namespace']}"
             headers = {
-                "Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}',
+                "Authorization": f"Bearer {credentials['huggingfacehub_api_token']}",
                 "Content-Type": "application/json",
             }
@@ -34,6 +34,7 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage

 class MinimaxLargeLanguageModel(LargeLanguageModel):
     model_apis = {
+        "minimax-text-01": MinimaxChatCompletionPro,
         "abab7-chat-preview": MinimaxChatCompletionPro,
         "abab6.5t-chat": MinimaxChatCompletionPro,
         "abab6.5s-chat": MinimaxChatCompletionPro,
@@ -0,0 +1,46 @@
model: minimax-text-01
label:
  en_US: Minimax-Text-01
model_type: llm
features:
  - agent-thought
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 1000192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0.01
    max: 1
    default: 0.1
  - name: top_p
    use_template: top_p
    min: 0.01
    max: 1
    default: 0.95
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 2048
    min: 1
    max: 1000192
  - name: mask_sensitive_info
    type: boolean
    default: true
    label:
      zh_Hans: 隐私保护
      en_US: Moderate
    help:
      zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
      en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
pricing:
  input: '0.001'
  output: '0.008'
  unit: '0.001'
  currency: RMB
@@ -44,9 +44,6 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         self._add_custom_parameters(credentials)
         self._add_function_call(model, credentials)
         user = user[:32] if user else None
-        # {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
-        if "response_format" in model_parameters:
-            model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
         return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

     def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -1,19 +1,11 @@
[SVG logo asset replaced: 162×36 wordmark (white fill plus #11101A letter paths) swapped for a 24×15 mark with a linear gradient (#2622FF to #A717FF); size 4.5 KiB before, 1.9 KiB after]
@@ -1,10 +1,3 @@
[SVG icon asset replaced: 32×36 gradient icon swapped for a minimal 24×15 black icon; size 1.5 KiB before, 228 B after]
@@ -0,0 +1,41 @@
model: Sao10K/L3-8B-Stheno-v3.2
label:
  zh_Hans: Sao10K/L3-8B-Stheno-v3.2
  en_US: Sao10K/L3-8B-Stheno-v3.2
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0005'
  output: '0.0005'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,40 @@
# Deepseek Models
- deepseek/deepseek_v3

# LLaMA Models
- meta-llama/llama-3.3-70b-instruct
- meta-llama/llama-3.2-11b-vision-instruct
- meta-llama/llama-3.2-3b-instruct
- meta-llama/llama-3.2-1b-instruct
- meta-llama/llama-3.1-70b-instruct
- meta-llama/llama-3.1-8b-instruct
- meta-llama/llama-3.1-8b-instruct-max
- meta-llama/llama-3.1-8b-instruct-bf16
- meta-llama/llama-3-70b-instruct
- meta-llama/llama-3-8b-instruct

# Mistral Models
- mistralai/mistral-nemo
- mistralai/mistral-7b-instruct

# Qwen Models
- qwen/qwen-2.5-72b-instruct
- qwen/qwen-2-72b-instruct
- qwen/qwen-2-vl-72b-instruct
- qwen/qwen-2-7b-instruct

# Other Models
- sao10k/L3-8B-Stheno-v3.2
- sao10k/l3-70b-euryale-v2.1
- sao10k/l31-70b-euryale-v2.2
- sao10k/l3-8b-lunaris
- jondurbin/airoboros-l2-70b
- cognitivecomputations/dolphin-mixtral-8x22b
- google/gemma-2-9b-it
- nousresearch/hermes-2-pro-llama-3-8b
- sophosympatheia/midnight-rose-70b
- gryphe/mythomax-l2-13b
- nousresearch/nous-hermes-llama2-13b
- openchat/openchat-7b
- teknium/openhermes-2.5-mistral-7b
- microsoft/wizardlm-2-8x22b
@@ -0,0 +1,41 @@
model: deepseek/deepseek_v3
label:
  zh_Hans: deepseek/deepseek_v3
  en_US: deepseek/deepseek_v3
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 64000
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0089'
  output: '0.0089'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: sao10k/l3-8b-lunaris
label:
  zh_Hans: sao10k/l3-8b-lunaris
  en_US: sao10k/l3-8b-lunaris
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0005'
  output: '0.0005'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: sao10k/l31-70b-euryale-v2.2
label:
  zh_Hans: sao10k/l31-70b-euryale-v2.2
  en_US: sao10k/l31-70b-euryale-v2.2
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 16000
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0148'
  output: '0.0148'
  unit: '0.0001'
  currency: USD
@@ -35,7 +35,7 @@ parameter_rules:
    max: 2
    default: 0
pricing:
-  input: '0.00063'
-  output: '0.00063'
+  input: '0.0004'
+  output: '0.0004'
  unit: '0.0001'
  currency: USD
@@ -7,7 +7,7 @@ features:
  - agent-thought
model_properties:
  mode: chat
-  context_size: 8192
+  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
@@ -35,7 +35,7 @@ parameter_rules:
    max: 2
    default: 0
pricing:
-  input: '0.0055'
-  output: '0.0076'
+  input: '0.0034'
+  output: '0.0039'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: meta-llama/llama-3.1-8b-instruct-bf16
label:
  zh_Hans: meta-llama/llama-3.1-8b-instruct-bf16
  en_US: meta-llama/llama-3.1-8b-instruct-bf16
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0006'
  output: '0.0006'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: meta-llama/llama-3.1-8b-instruct-max
label:
  zh_Hans: meta-llama/llama-3.1-8b-instruct-max
  en_US: meta-llama/llama-3.1-8b-instruct-max
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 16384
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0005'
  output: '0.0005'
  unit: '0.0001'
  currency: USD
@@ -7,7 +7,7 @@ features:
  - agent-thought
model_properties:
  mode: chat
-  context_size: 8192
+  context_size: 16384
parameter_rules:
  - name: temperature
    use_template: temperature
@@ -35,7 +35,7 @@ parameter_rules:
    max: 2
    default: 0
pricing:
-  input: '0.001'
-  output: '0.001'
+  input: '0.0005'
+  output: '0.0005'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: meta-llama/llama-3.2-11b-vision-instruct
label:
  zh_Hans: meta-llama/llama-3.2-11b-vision-instruct
  en_US: meta-llama/llama-3.2-11b-vision-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0006'
  output: '0.0006'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: meta-llama/llama-3.2-1b-instruct
label:
  zh_Hans: meta-llama/llama-3.2-1b-instruct
  en_US: meta-llama/llama-3.2-1b-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 131000
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0002'
  output: '0.0002'
  unit: '0.0001'
  currency: USD
@@ -1,7 +1,7 @@
-model: meta-llama/llama-3.1-405b-instruct
+model: meta-llama/llama-3.2-3b-instruct
label:
-  zh_Hans: meta-llama/llama-3.1-405b-instruct
-  en_US: meta-llama/llama-3.1-405b-instruct
+  zh_Hans: meta-llama/llama-3.2-3b-instruct
+  en_US: meta-llama/llama-3.2-3b-instruct
model_type: llm
features:
  - agent-thought
@@ -35,7 +35,7 @@ parameter_rules:
    max: 2
    default: 0
pricing:
-  input: '0.03'
-  output: '0.05'
+  input: '0.0003'
+  output: '0.0005'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: meta-llama/llama-3.3-70b-instruct
label:
  zh_Hans: meta-llama/llama-3.3-70b-instruct
  en_US: meta-llama/llama-3.3-70b-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0039'
  output: '0.0039'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: mistralai/mistral-nemo
label:
  zh_Hans: mistralai/mistral-nemo
  en_US: mistralai/mistral-nemo
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0017'
  output: '0.0017'
  unit: '0.0001'
  currency: USD
@@ -35,7 +35,7 @@ parameter_rules:
    max: 2
    default: 0
pricing:
-  input: '0.00119'
-  output: '0.00119'
+  input: '0.0009'
+  output: '0.0009'
  unit: '0.0001'
  currency: USD
@@ -1,7 +1,7 @@
-model: lzlv_70b
+model: openchat/openchat-7b
label:
-  zh_Hans: lzlv_70b
-  en_US: lzlv_70b
+  zh_Hans: openchat/openchat-7b
+  en_US: openchat/openchat-7b
model_type: llm
features:
  - agent-thought
@@ -35,7 +35,7 @@ parameter_rules:
    max: 2
    default: 0
pricing:
-  input: '0.0058'
-  output: '0.0078'
+  input: '0.0006'
+  output: '0.0006'
  unit: '0.0001'
  currency: USD
@@ -1,7 +1,7 @@
-model: Nous-Hermes-2-Mixtral-8x7B-DPO
+model: qwen/qwen-2-72b-instruct
label:
-  zh_Hans: Nous-Hermes-2-Mixtral-8x7B-DPO
-  en_US: Nous-Hermes-2-Mixtral-8x7B-DPO
+  zh_Hans: qwen/qwen-2-72b-instruct
+  en_US: qwen/qwen-2-72b-instruct
model_type: llm
features:
  - agent-thought
@@ -35,7 +35,7 @@ parameter_rules:
    max: 2
    default: 0
pricing:
-  input: '0.0027'
-  output: '0.0027'
+  input: '0.0034'
+  output: '0.0039'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: qwen/qwen-2-7b-instruct
label:
  zh_Hans: qwen/qwen-2-7b-instruct
  en_US: qwen/qwen-2-7b-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.00054'
  output: '0.00054'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: qwen/qwen-2-vl-72b-instruct
label:
  zh_Hans: qwen/qwen-2-vl-72b-instruct
  en_US: qwen/qwen-2-vl-72b-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0045'
  output: '0.0045'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,41 @@
model: qwen/qwen-2.5-72b-instruct
label:
  zh_Hans: qwen/qwen-2.5-72b-instruct
  en_US: qwen/qwen-2.5-72b-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
pricing:
  input: '0.0038'
  output: '0.004'
  unit: '0.0001'
  currency: USD
@@ -35,7 +35,7 @@ parameter_rules:
    max: 2
    default: 0
pricing:
-  input: '0.0064'
-  output: '0.0064'
+  input: '0.0062'
+  output: '0.0062'
  unit: '0.0001'
  currency: USD
@@ -1,6 +1,6 @@
provider: novita
label:
-  en_US: novita.ai
+  en_US: Novita AI
description:
  en_US: An LLM API that matches various application scenarios with high cost-effectiveness.
  zh_Hans: 适配多种海外应用场景的高性价比 LLM API
@@ -11,10 +11,10 @@ icon_large:
background: "#eadeff"
help:
  title:
-    en_US: Get your API key from novita.ai
-    zh_Hans: 从 novita.ai 获取 API Key
+    en_US: Get your API key from Novita AI
+    zh_Hans: 从 Novita AI 获取 API Key
  url:
-    en_US: https://novita.ai/settings#key-management?utm_source=dify&utm_medium=ch&utm_campaign=api
+    en_US: https://novita.ai/settings/key-management?utm_source=dify&utm_medium=ch&utm_campaign=api
supported_model_types:
  - llm
configurate_methods:
@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator
 from typing import Any, Optional, Union, cast

@@ -621,11 +622,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)

        # o1 compatibility
        block_as_stream = False
        if model.startswith("o1"):
            if "max_tokens" in model_parameters:
                model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                del model_parameters["max_tokens"]

            if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
                if stream:
                    block_as_stream = True
                    stream = False
                    if "stream_options" in extra_model_kwargs:
                        del extra_model_kwargs["stream_options"]

                if "stop" in extra_model_kwargs:
                    del extra_model_kwargs["stop"]

@@ -642,7 +651,45 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
        if stream:
            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)

-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle an llm chat response that was generated in blocking mode but
+        must be returned as a stream
+        :param block_result: the complete blocking chat result
+        :param prompt_messages: prompt messages
+        :param stop: stop words
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )

    def _handle_chat_generate_response(
        self,
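The o1 handling above converts a blocking result into a one-chunk stream. A standalone sketch of the same pattern (illustrative names only, not Dify's API):

```python
from collections.abc import Generator

def block_as_stream(full_text: str) -> Generator[str, None, None]:
    # o1 models reject stream=True, so the complete blocking answer is
    # wrapped in a generator that yields exactly one chunk.
    yield full_text

for chunk in block_as_stream("complete o1 answer"):
    print(chunk)
```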
@@ -29,9 +29,6 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
-        # {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
-        if "response_format" in model_parameters:
-            model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
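The three deleted lines performed this conversion, shown standalone for clarity (OpenAI-compatible endpoints expect the nested object form):

```python
model_parameters = {"response_format": "json_object"}
if "response_format" in model_parameters:
    model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
print(model_parameters)  # {'response_format': {'type': 'json_object'}}
```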
@@ -21,7 +21,7 @@ class SparkLLMClient:
        domain = api_domain

        model_api_configs = {
-            "spark-lite": {"version": "v1.1", "chat_domain": "general"},
+            "spark-lite": {"version": "v1.1", "chat_domain": "lite"},
            "spark-pro": {"version": "v3.1", "chat_domain": "generalv3"},
            "spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"},
            "spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"},
@@ -257,8 +257,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        for index, response in enumerate(responses):
            if response.status_code not in {200, HTTPStatus.OK}:
                raise ServiceUnavailableError(
-                    f"Failed to invoke model {model}, status code: {response.status_code}, "
-                    f"message: {response.message}"
+                    f"Failed to invoke model {model}, status code: {response.status_code}, message: {response.message}"
                )

        resp_finish_reason = response.output.choices[0].finish_reason
@@ -146,7 +146,7 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
        elif credentials["completion_type"] == "completion":
            completion_type = LLMMode.COMPLETION.value
        else:
-            raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
+            raise ValueError(f"completion_type {credentials['completion_type']} is not supported")

        entity = AIModelEntity(
            model=model,
@@ -18,72 +18,93 @@ class ModelConfig(BaseModel):


configs: dict[str, ModelConfig] = {
+    "Doubao-1.5-vision-pro-32k": ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
+    ),
+    "Doubao-1.5-pro-32k": ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
+        features=[ModelFeature.AGENT_THOUGHT],
+    ),
+    "Doubao-1.5-lite-32k": ModelConfig(
+        properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
+        features=[ModelFeature.AGENT_THOUGHT],
+    ),
+    "Doubao-1.5-pro-256k": ModelConfig(
+        properties=ModelProperties(context_size=262144, max_tokens=12288, mode=LLMMode.CHAT),
+        features=[ModelFeature.AGENT_THOUGHT],
+    ),
    "Doubao-vision-pro-32k": ModelConfig(
        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[ModelFeature.VISION],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
    ),
    "Doubao-vision-lite-32k": ModelConfig(
        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[ModelFeature.VISION],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
    ),
    "Doubao-pro-4k": ModelConfig(
        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[ModelFeature.TOOL_CALL],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
    ),
    "Doubao-lite-4k": ModelConfig(
        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[ModelFeature.TOOL_CALL],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
    ),
    "Doubao-pro-32k": ModelConfig(
        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[ModelFeature.TOOL_CALL],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
    ),
    "Doubao-lite-32k": ModelConfig(
        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[ModelFeature.TOOL_CALL],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
    ),
    "Doubao-pro-256k": ModelConfig(
        properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[],
+        features=[ModelFeature.AGENT_THOUGHT],
    ),
    "Doubao-pro-128k": ModelConfig(
        properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[ModelFeature.TOOL_CALL],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
    ),
    "Doubao-lite-128k": ModelConfig(
-        properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[]
+        properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[ModelFeature.AGENT_THOUGHT],
    ),
    "Skylark2-pro-4k": ModelConfig(
-        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[]
+        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[ModelFeature.AGENT_THOUGHT],
    ),
    "Llama3-8B": ModelConfig(
-        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[]
+        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
+        features=[ModelFeature.AGENT_THOUGHT],
    ),
    "Llama3-70B": ModelConfig(
-        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[]
+        properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT),
+        features=[ModelFeature.AGENT_THOUGHT],
    ),
    "Moonshot-v1-8k": ModelConfig(
        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[ModelFeature.TOOL_CALL],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
    ),
    "Moonshot-v1-32k": ModelConfig(
        properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT),
-        features=[ModelFeature.TOOL_CALL],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
    ),
    "Moonshot-v1-128k": ModelConfig(
        properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT),
-        features=[ModelFeature.TOOL_CALL],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
    ),
    "GLM3-130B": ModelConfig(
        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[ModelFeature.TOOL_CALL],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
    ),
    "GLM3-130B-Fin": ModelConfig(
        properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT),
-        features=[ModelFeature.TOOL_CALL],
+        features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL],
    ),
    "Mistral-7B": ModelConfig(
-        properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), features=[]
+        properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT),
+        features=[ModelFeature.AGENT_THOUGHT],
    ),
}
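A hedged sketch of how the mapping above is consumed; it assumes the `configs`, `ModelConfig`, and `ModelFeature` definitions from this file are in scope:

```python
config = configs["Doubao-pro-32k"]
print(config.properties.context_size)                 # 32768
print(ModelFeature.AGENT_THOUGHT in config.features)  # True after this change
```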
@@ -118,6 +118,30 @@ model_credential_schema:
      type: select
      required: true
      options:
+        - label:
+            en_US: Doubao-1.5-vision-pro-32k
+          value: Doubao-1.5-vision-pro-32k
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: Doubao-1.5-pro-32k
+          value: Doubao-1.5-pro-32k
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: Doubao-1.5-lite-32k
+          value: Doubao-1.5-lite-32k
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: Doubao-1.5-pro-256k
+          value: Doubao-1.5-pro-256k
+          show_on:
+            - variable: __model_type
+              value: llm
        - label:
            en_US: Doubao-vision-pro-32k
          value: Doubao-vision-pro-32k
@@ -41,15 +41,15 @@ class BaiduAccessToken:
        resp = response.json()
        if "error" in resp:
            if resp["error"] == "invalid_client":
-                raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
+                raise InvalidAPIKeyError(f"Invalid API key or secret key: {resp['error_description']}")
            elif resp["error"] == "unknown_error":
-                raise InternalServerError(f'Internal server error: {resp["error_description"]}')
+                raise InternalServerError(f"Internal server error: {resp['error_description']}")
            elif resp["error"] == "invalid_request":
-                raise BadRequestError(f'Bad request: {resp["error_description"]}')
+                raise BadRequestError(f"Bad request: {resp['error_description']}")
            elif resp["error"] == "rate_limit_exceeded":
-                raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
+                raise RateLimitReachedError(f"Rate limit reached: {resp['error_description']}")
            else:
-                raise Exception(f'Unknown error: {resp["error_description"]}')
+                raise Exception(f"Unknown error: {resp['error_description']}")

        return resp["access_token"]
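These quote swaps all address the same constraint: before Python 3.12, reusing the outer quote character inside an f-string expression is a syntax error. A minimal illustration:

```python
resp = {"error_description": "boom"}
# f'... {resp["error_description"]}' fails to parse on Python < 3.12;
# double-outer/single-inner quoting works on all supported versions.
print(f"Bad request: {resp['error_description']}")
```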
@@ -406,7 +406,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
            elif credentials["completion_type"] == "completion":
                completion_type = LLMMode.COMPLETION.value
            else:
-                raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
+                raise ValueError(f"completion_type {credentials['completion_type']} is not supported")
        else:
            extra_args = XinferenceHelper.get_xinference_extra_parameter(
                server_url=credentials["server_url"],
@@ -472,7 +472,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
        api_key = credentials.get("api_key") or "abc"

        client = OpenAI(
-            base_url=f'{credentials["server_url"]}/v1',
+            base_url=f"{credentials['server_url']}/v1",
            api_key=api_key,
            max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES),
            timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT),
@@ -0,0 +1,66 @@
model: glm-4-air-0111
label:
  en_US: glm-4-air-0111
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.95
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
      en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, the more stable or certain the output will be. It is recommended that you adjust the top_p or temperature parameter according to the application scenario, but do not adjust both parameters at the same time.
  - name: top_p
    use_template: top_p
    default: 0.7
    help:
      zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
      en_US: An alternative to temperature sampling, called nucleus sampling. The value range is the open interval (0.0, 1.0), which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results of the tokens within the top_p probability mass. For example, 0.1 means the model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameter according to the application scenario, but do not adjust both parameters at the same time.
  - name: do_sample
    label:
      zh_Hans: 采样策略
      en_US: Sampling strategy
    type: boolean
    help:
      zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
      en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
    default: true
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 4095
  - name: web_search
    type: boolean
    label:
      zh_Hans: 联网搜索
      en_US: Web Search
    default: false
    help:
      zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
      en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" for itself whether to use Internet search results based on its internal logic.
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: Response Format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.0005'
  output: '0.0005'
  unit: '0.001'
  currency: RMB
@@ -87,6 +87,6 @@ class CommonValidator:
            if value.lower() not in {"true", "false"}:
                raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")

-            value = True if value.lower() == "true" else False
+            value = value.lower() == "true"

        return value
@@ -6,6 +6,7 @@ from pydantic import BaseModel, ValidationInfo, field_validator

class TracingProviderEnum(Enum):
    LANGFUSE = "langfuse"
    LANGSMITH = "langsmith"
+    OPIK = "opik"


class BaseTracingConfig(BaseModel):
@@ -56,5 +57,36 @@ class LangSmithConfig(BaseTracingConfig):
        return v


class OpikConfig(BaseTracingConfig):
    """
    Model class for Opik tracing config.
    """

    api_key: str | None = None
    project: str | None = None
    workspace: str | None = None
    url: str = "https://www.comet.com/opik/api/"

    @field_validator("project")
    @classmethod
    def project_validator(cls, v, info: ValidationInfo):
        if v is None or v == "":
            v = "Default Project"

        return v

    @field_validator("url")
    @classmethod
    def url_validator(cls, v, info: ValidationInfo):
        if v is None or v == "":
            v = "https://www.comet.com/opik/api/"
        if not v.startswith(("https://", "http://")):
            raise ValueError("url must start with https:// or http://")
        if not v.endswith("/api/"):
            raise ValueError("url should end with /api/")

        return v


OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
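A hedged usage sketch of the validators above (assumes pydantic v2 and that OpikConfig is importable from this module):

```python
from core.ops.entities.config_entity import OpikConfig

cfg = OpikConfig(api_key="opik-key", workspace="my-ws", project="")
print(cfg.project)  # "Default Project": empty values fall back in project_validator
print(cfg.url)      # "https://www.comet.com/opik/api/": the default endpoint

# A URL missing the trailing /api/ is rejected by url_validator, surfacing
# as a pydantic ValidationError:
# OpikConfig(url="https://www.comet.com/opik")
```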
api/core/ops/opik_trace/__init__.py (new file, 0 lines)
api/core/ops/opik_trace/opik_trace.py (new file, 469 lines)
@@ -0,0 +1,469 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional, cast

from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    DatasetRetrievalTraceInfo,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    TraceTaskName,
    WorkflowTraceInfo,
)
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution

logger = logging.getLogger(__name__)


def wrap_dict(key_name, data):
    """Make sure that the input data is a dict"""
    if not isinstance(data, dict):
        return {key_name: data}

    return data


def wrap_metadata(metadata, **kwargs):
    """Add common metadata to all Traces and Spans"""
    metadata["created_from"] = "dify"

    metadata.update(kwargs)

    return metadata


def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]):
    """Opik needs UUIDv7 while Dify uses UUIDv4 as the identifier of most
    messages and objects. The type hints of BaseTraceInfo indicate that an
    object's start_time and message_id could be null, in which case we cannot
    map it to a UUIDv7. Given that we have no way to identify such an object
    uniquely, generate a new random UUIDv7 in that case.
    """

    if user_datetime is None:
        user_datetime = datetime.now()

    if user_uuid is None:
        user_uuid = str(uuid.uuid4())

    return uuid4_to_uuid7(user_datetime, user_uuid)


class OpikDataTrace(BaseTraceInstance):
    def __init__(
        self,
        opik_config: OpikConfig,
    ):
        super().__init__(opik_config)
        self.opik_client = Opik(
            project_name=opik_config.project,
            workspace=opik_config.workspace,
            host=opik_config.url,
            api_key=opik_config.api_key,
        )
        self.project = opik_config.project
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

    def trace(self, trace_info: BaseTraceInfo):
        if isinstance(trace_info, WorkflowTraceInfo):
            self.workflow_trace(trace_info)
        if isinstance(trace_info, MessageTraceInfo):
            self.message_trace(trace_info)
        if isinstance(trace_info, ModerationTraceInfo):
            self.moderation_trace(trace_info)
        if isinstance(trace_info, SuggestedQuestionTraceInfo):
            self.suggested_question_trace(trace_info)
        if isinstance(trace_info, DatasetRetrievalTraceInfo):
            self.dataset_retrieval_trace(trace_info)
        if isinstance(trace_info, ToolTraceInfo):
            self.tool_trace(trace_info)
        if isinstance(trace_info, GenerateNameTraceInfo):
            self.generate_name_trace(trace_info)

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        dify_trace_id = trace_info.workflow_run_id
        opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
        workflow_metadata = wrap_metadata(
            trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
        )
        root_span_id = None

        if trace_info.message_id:
            dify_trace_id = trace_info.message_id
            opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)

            trace_data = {
                "id": opik_trace_id,
                "name": TraceTaskName.MESSAGE_TRACE.value,
                "start_time": trace_info.start_time,
                "end_time": trace_info.end_time,
                "metadata": workflow_metadata,
                "input": wrap_dict("input", trace_info.workflow_run_inputs),
                "output": wrap_dict("output", trace_info.workflow_run_outputs),
                "tags": ["message", "workflow"],
                "project_name": self.project,
            }
            self.add_trace(trace_data)

            root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
            span_data = {
                "id": root_span_id,
                "parent_span_id": None,
                "trace_id": opik_trace_id,
                "name": TraceTaskName.WORKFLOW_TRACE.value,
                "input": wrap_dict("input", trace_info.workflow_run_inputs),
                "output": wrap_dict("output", trace_info.workflow_run_outputs),
                "start_time": trace_info.start_time,
                "end_time": trace_info.end_time,
                "metadata": workflow_metadata,
                "tags": ["workflow"],
                "project_name": self.project,
            }
            self.add_span(span_data)
        else:
            trace_data = {
                "id": opik_trace_id,
                "name": TraceTaskName.MESSAGE_TRACE.value,
                "start_time": trace_info.start_time,
                "end_time": trace_info.end_time,
                "metadata": workflow_metadata,
                "input": wrap_dict("input", trace_info.workflow_run_inputs),
                "output": wrap_dict("output", trace_info.workflow_run_outputs),
                "tags": ["workflow"],
                "project_name": self.project,
            }
            self.add_trace(trace_data)

        # through workflow_run_id get all_nodes_execution
        workflow_nodes_execution_id_records = (
            db.session.query(WorkflowNodeExecution.id)
            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
            .all()
        )

        for node_execution_id_record in workflow_nodes_execution_id_records:
            node_execution = (
                db.session.query(
                    WorkflowNodeExecution.id,
                    WorkflowNodeExecution.tenant_id,
                    WorkflowNodeExecution.app_id,
                    WorkflowNodeExecution.title,
                    WorkflowNodeExecution.node_type,
                    WorkflowNodeExecution.status,
                    WorkflowNodeExecution.inputs,
                    WorkflowNodeExecution.outputs,
                    WorkflowNodeExecution.created_at,
                    WorkflowNodeExecution.elapsed_time,
                    WorkflowNodeExecution.process_data,
                    WorkflowNodeExecution.execution_metadata,
                )
                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
                .first()
            )

            if not node_execution:
                continue

            node_execution_id = node_execution.id
            tenant_id = node_execution.tenant_id
            app_id = node_execution.app_id
            node_name = node_execution.title
            node_type = node_execution.node_type
            status = node_execution.status
            if node_type == "llm":
                inputs = (
                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
                )
            else:
                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
            created_at = node_execution.created_at or datetime.now()
            elapsed_time = node_execution.elapsed_time
            finished_at = created_at + timedelta(seconds=elapsed_time)

            execution_metadata = (
                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
            )
            metadata = execution_metadata.copy()
            metadata.update(
                {
                    "workflow_run_id": trace_info.workflow_run_id,
                    "node_execution_id": node_execution_id,
                    "tenant_id": tenant_id,
                    "app_id": app_id,
                    "app_name": node_name,
                    "node_type": node_type,
                    "status": status,
                }
            )

            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}

            provider = None
            model = None
            total_tokens = 0
            completion_tokens = 0
            prompt_tokens = 0

            if process_data and process_data.get("model_mode") == "chat":
                run_type = "llm"
                provider = process_data.get("model_provider", None)
                model = process_data.get("model_name", "")
                metadata.update(
                    {
                        "ls_provider": provider,
                        "ls_model_name": model,
                    }
                )

                try:
                    if outputs.get("usage"):
                        total_tokens = outputs["usage"].get("total_tokens", 0)
                        prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
                        completion_tokens = outputs["usage"].get("completion_tokens", 0)
                except Exception:
                    logger.error("Failed to extract usage", exc_info=True)

            else:
                run_type = "tool"

            parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id

            if not total_tokens:
                total_tokens = execution_metadata.get("total_tokens", 0)

            span_data = {
                "trace_id": opik_trace_id,
                "id": prepare_opik_uuid(created_at, node_execution_id),
                "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
                "name": node_type,
                "type": run_type,
                "start_time": created_at,
                "end_time": finished_at,
                "metadata": wrap_metadata(metadata),
                "input": wrap_dict("input", inputs),
                "output": wrap_dict("output", outputs),
                "tags": ["node_execution"],
                "project_name": self.project,
                "usage": {
                    "total_tokens": total_tokens,
                    "completion_tokens": completion_tokens,
                    "prompt_tokens": prompt_tokens,
                },
                "model": model,
                "provider": provider,
            }

            self.add_span(span_data)

    def message_trace(self, trace_info: MessageTraceInfo):
        # get message file data
        file_list = cast(list[str], trace_info.file_list) or []
        message_file_data: Optional[MessageFile] = trace_info.message_file_data

        if message_file_data is not None:
            file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
            file_list.append(file_url)

        message_data = trace_info.message_data
        if message_data is None:
            return

        metadata = trace_info.metadata
        message_id = trace_info.message_id

        user_id = message_data.from_account_id
        metadata["user_id"] = user_id
        metadata["file_list"] = file_list

        if message_data.from_end_user_id:
            end_user_data: Optional[EndUser] = (
                db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
            )
            if end_user_data is not None:
                end_user_id = end_user_data.session_id
                metadata["end_user_id"] = end_user_id

        trace_data = {
            "id": prepare_opik_uuid(trace_info.start_time, message_id),
            "name": TraceTaskName.MESSAGE_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(metadata),
            "input": trace_info.inputs,
            "output": message_data.answer,
            "tags": ["message", str(trace_info.conversation_mode)],
            "project_name": self.project,
        }
        trace = self.add_trace(trace_data)

        span_data = {
            "trace_id": trace.id,
            "name": "llm",
            "type": "llm",
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(metadata),
            "input": {"input": trace_info.inputs},
            "output": {"output": message_data.answer},
            "tags": ["llm", str(trace_info.conversation_mode)],
            "usage": {
                "completion_tokens": trace_info.answer_tokens,
                "prompt_tokens": trace_info.message_tokens,
                "total_tokens": trace_info.total_tokens,
            },
            "project_name": self.project,
        }
        self.add_span(span_data)

    def moderation_trace(self, trace_info: ModerationTraceInfo):
        if trace_info.message_data is None:
            return

        start_time = trace_info.start_time or trace_info.message_data.created_at

        span_data = {
            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
            "name": TraceTaskName.MODERATION_TRACE.value,
            "type": "tool",
            "start_time": start_time,
            "end_time": trace_info.end_time or trace_info.message_data.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": {
                "action": trace_info.action,
                "flagged": trace_info.flagged,
                "preset_response": trace_info.preset_response,
                "inputs": trace_info.inputs,
            },
            "tags": ["moderation"],
        }

        self.add_span(span_data)

    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
        message_data = trace_info.message_data
        if message_data is None:
            return

        start_time = trace_info.start_time or message_data.created_at

        span_data = {
            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
            "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
            "type": "tool",
            "start_time": start_time,
            "end_time": trace_info.end_time or message_data.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": wrap_dict("output", trace_info.suggested_question),
            "tags": ["suggested_question"],
        }

        self.add_span(span_data)

    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
        if trace_info.message_data is None:
            return

        start_time = trace_info.start_time or trace_info.message_data.created_at

        span_data = {
            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
            "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
            "type": "tool",
            "start_time": start_time,
            "end_time": trace_info.end_time or trace_info.message_data.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": {"documents": trace_info.documents},
            "tags": ["dataset_retrieval"],
        }

        self.add_span(span_data)

    def tool_trace(self, trace_info: ToolTraceInfo):
        span_data = {
            "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
            "name": trace_info.tool_name,
            "type": "tool",
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.tool_inputs),
            "output": wrap_dict("output", trace_info.tool_outputs),
            "tags": ["tool", trace_info.tool_name],
        }

        self.add_span(span_data)

    def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
        trace_data = {
            "id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
            "name": TraceTaskName.GENERATE_NAME_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": trace_info.inputs,
            "output": trace_info.outputs,
            "tags": ["generate_name"],
            "project_name": self.project,
        }

        trace = self.add_trace(trace_data)

        span_data = {
            "trace_id": trace.id,
            "name": TraceTaskName.GENERATE_NAME_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": wrap_dict("output", trace_info.outputs),
            "tags": ["generate_name"],
        }

        self.add_span(span_data)

    def add_trace(self, opik_trace_data: dict) -> Trace:
        try:
            trace = self.opik_client.trace(**opik_trace_data)
            logger.debug("Opik Trace created successfully")
            return trace
        except Exception as e:
            raise ValueError(f"Opik Failed to create trace: {str(e)}")

    def add_span(self, opik_span_data: dict):
        try:
            self.opik_client.span(**opik_span_data)
            logger.debug("Opik Span created successfully")
        except Exception as e:
            raise ValueError(f"Opik Failed to create span: {str(e)}")

    def api_check(self):
        try:
            self.opik_client.auth_check()
            return True
        except Exception as e:
            logger.info(f"Opik API check failed: {str(e)}", exc_info=True)
            raise ValueError(f"Opik API check failed: {str(e)}")

    def get_project_url(self):
        try:
            return self.opik_client.get_project_url(project_name=self.project)
        except Exception as e:
            logger.info(f"Opik get run url failed: {str(e)}", exc_info=True)
            raise ValueError(f"Opik get run url failed: {str(e)}")
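A hedged wiring sketch for the new tracer; the credential values are placeholders, and in Dify the instance is normally built through provider_config_map in ops_trace_manager rather than by hand:

```python
from core.ops.entities.config_entity import OpikConfig
from core.ops.opik_trace.opik_trace import OpikDataTrace

tracer = OpikDataTrace(OpikConfig(api_key="opik-key", workspace="my-ws", project="dify"))
tracer.api_check()  # raises ValueError if the Opik credentials are rejected
```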
@@ -17,6 +17,7 @@ from core.ops.entities.config_entity import (
    OPS_FILE_PATH,
    LangfuseConfig,
    LangSmithConfig,
+    OpikConfig,
    TracingProviderEnum,
)
from core.ops.entities.trace_entity import (
@@ -32,6 +33,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
+from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
@@ -52,6 +54,12 @@ provider_config_map: dict[str, dict[str, Any]] = {
        "other_keys": ["project", "endpoint"],
        "trace_instance": LangSmithDataTrace,
    },
+    TracingProviderEnum.OPIK.value: {
+        "config_class": OpikConfig,
+        "secret_keys": ["api_key"],
+        "other_keys": ["project", "url", "workspace"],
+        "trace_instance": OpikDataTrace,
+    },
}
@@ -22,7 +22,12 @@ from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
+from core.model_runtime.entities.provider_entities import (
+    ConfigurateMethod,
+    CredentialFormSchema,
+    FormType,
+    ProviderEntity,
+)
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db
@@ -835,11 +840,18 @@ class ProviderManager:
        :return:
        """
        # Get provider model credential secret variables
-        model_credential_secret_variables = self._extract_secret_variables(
-            provider_entity.model_credential_schema.credential_form_schemas
-            if provider_entity.model_credential_schema
-            else []
-        )
+        if ConfigurateMethod.PREDEFINED_MODEL in provider_entity.configurate_methods:
+            model_credential_secret_variables = self._extract_secret_variables(
+                provider_entity.provider_credential_schema.credential_form_schemas
+                if provider_entity.provider_credential_schema
+                else []
+            )
+        else:
+            model_credential_secret_variables = self._extract_secret_variables(
+                provider_entity.model_credential_schema.credential_form_schemas
+                if provider_entity.model_credential_schema
+                else []
+            )

        model_settings: list[ModelSettings] = []
        if not provider_model_settings:
@@ -258,7 +258,7 @@ class LindormVectorStore(BaseVector):
        hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
        ivfpq_m = kwargs.pop("ivfpq_m", dimension)
        nlist = kwargs.pop("nlist", 1000)
-        centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
+        centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", nlist >= 5000)
        centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
        centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
        centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
@@ -305,7 +305,7 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
    if method_name == "ivfpq":
        ivfpq_m = kwargs["ivfpq_m"]
        nlist = kwargs["nlist"]
-        centroids_use_hnsw = True if nlist > 10000 else False
+        centroids_use_hnsw = nlist > 10000
        centroids_hnsw_m = 24
        centroids_hnsw_ef_construct = 500
        centroids_hnsw_ef_search = 100
@@ -57,6 +57,11 @@ CREATE TABLE IF NOT EXISTS {table_name} (
) using heap;
"""

+SQL_CREATE_INDEX = """
+CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
+USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
+"""
+

class PGVector(BaseVector):
    def __init__(self, collection_name: str, config: PGVectorConfig):
@@ -205,7 +210,10 @@ class PGVector(BaseVector):
        with self._get_cursor() as cur:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
-            # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
+            # The PG hnsw index only supports 2000 dimensions or fewer
+            # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
+            if dimension <= 2000:
+                cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
            redis_client.set(collection_exist_cache_key, 1, ex=3600)
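A self-contained sketch of the dimension guard introduced above: pgvector's hnsw index supports at most 2000 dimensions, so index creation has to stay conditional.

```python
SQL_CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""

def maybe_index_sql(table_name: str, dimension: int) -> str | None:
    # Returns the index DDL only when the embedding width is indexable.
    return SQL_CREATE_INDEX.format(table_name=table_name) if dimension <= 2000 else None

print(maybe_index_sql("embedding_v1", 1536) is not None)  # True
print(maybe_index_sql("embedding_v1", 3072))              # None: too wide for hnsw
```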
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
import requests
|
||||
|
||||
@@ -14,48 +14,47 @@ class FirecrawlApp:
|
||||
if self.api_key is None and self.base_url == "https://api.firecrawl.dev":
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def scrape_url(self, url, params=None) -> dict:
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
json_data = {"url": url}
|
||||
def scrape_url(self, url, params=None) -> dict[str, Any]:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/scrape
|
||||
headers = self._prepare_headers()
|
||||
json_data = {
|
||||
"url": url,
|
||||
"formats": ["markdown"],
|
||||
"onlyMainContent": True,
|
||||
"timeout": 30000,
|
||||
}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
|
||||
response = self._post_request(f"{self.base_url}/v1/scrape", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if response_data["success"] == True:
|
||||
data = response_data["data"]
|
||||
return {
|
||||
"title": data.get("metadata").get("title"),
|
||||
"description": data.get("metadata").get("description"),
|
||||
"source_url": data.get("metadata").get("sourceURL"),
|
||||
"markdown": data.get("markdown"),
|
||||
}
|
||||
else:
|
||||
raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')
|
||||
|
||||
elif response.status_code in {402, 409, 500}:
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
|
||||
data = response_data["data"]
|
||||
return self._extract_common_fields(data)
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
self._handle_error(response, "scrape URL")
|
||||
return {} # Avoid additional exception after handling error
|
||||
else:
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")

     def crawl_url(self, url, params=None) -> str:
+        # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
         headers = self._prepare_headers()
         json_data = {"url": url}
         if params:
             json_data.update(params)
-        response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
+        response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
         if response.status_code == 200:
-            job_id = response.json().get("jobId")
+            # There are also two other fields in the response: "success" (bool) and "url" (str)
+            job_id = response.json().get("id")
             return cast(str, job_id)
         else:
             self._handle_error(response, "start crawl job")
+            # FIXME: unreachable code for mypy
+            return ""  # unreachable

-    def check_crawl_status(self, job_id) -> dict:
+    def check_crawl_status(self, job_id) -> dict[str, Any]:
         headers = self._prepare_headers()
-        response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers)
+        response = self._get_request(f"{self.base_url}/v1/crawl/{job_id}", headers)
         if response.status_code == 200:
             crawl_status_response = response.json()
             if crawl_status_response.get("status") == "completed":
@@ -66,42 +65,48 @@ class FirecrawlApp:
                 url_data_list = []
                 for item in data:
                     if isinstance(item, dict) and "metadata" in item and "markdown" in item:
-                        url_data = {
-                            "title": item.get("metadata", {}).get("title"),
-                            "description": item.get("metadata", {}).get("description"),
-                            "source_url": item.get("metadata", {}).get("sourceURL"),
-                            "markdown": item.get("markdown"),
-                        }
+                        url_data = self._extract_common_fields(item)
                         url_data_list.append(url_data)
                 if url_data_list:
                     file_key = "website_files/" + job_id + ".txt"
-                    if storage.exists(file_key):
-                        storage.delete(file_key)
-                    storage.save(file_key, json.dumps(url_data_list).encode("utf-8"))
-                return {
-                    "status": "completed",
-                    "total": crawl_status_response.get("total"),
-                    "current": crawl_status_response.get("current"),
-                    "data": url_data_list,
-                }
+                    try:
+                        if storage.exists(file_key):
+                            storage.delete(file_key)
+                        storage.save(file_key, json.dumps(url_data_list).encode("utf-8"))
+                    except Exception as e:
+                        raise Exception(f"Error saving crawl data: {e}")
+                return self._format_crawl_status_response("completed", crawl_status_response, url_data_list)
             else:
-                return {
-                    "status": crawl_status_response.get("status"),
-                    "total": crawl_status_response.get("total"),
-                    "current": crawl_status_response.get("current"),
-                    "data": [],
-                }
+                return self._format_crawl_status_response(
+                    crawl_status_response.get("status"), crawl_status_response, []
+                )
         else:
             self._handle_error(response, "check crawl status")
+            # FIXME: unreachable code for mypy
+            return {}  # unreachable

-    def _prepare_headers(self):
+    def _format_crawl_status_response(
+        self, status: str, crawl_status_response: dict[str, Any], url_data_list: list[dict[str, Any]]
+    ) -> dict[str, Any]:
+        return {
+            "status": status,
+            "total": crawl_status_response.get("total"),
+            "current": crawl_status_response.get("completed"),
+            "data": url_data_list,
+        }
+
+    def _extract_common_fields(self, item: dict[str, Any]) -> dict[str, Any]:
+        return {
+            "title": item.get("metadata", {}).get("title"),
+            "description": item.get("metadata", {}).get("description"),
+            "source_url": item.get("metadata", {}).get("sourceURL"),
+            "markdown": item.get("markdown"),
+        }
+
+    def _prepare_headers(self) -> dict[str, Any]:
         return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

-    def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
+    def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> requests.Response:
         for attempt in range(retries):
             response = requests.post(url, headers=headers, json=data)
             if response.status_code == 502:
@@ -110,7 +115,7 @@ class FirecrawlApp:
             return response
         return response

-    def _get_request(self, url, headers, retries=3, backoff_factor=0.5):
+    def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> requests.Response:
         for attempt in range(retries):
             response = requests.get(url, headers=headers)
             if response.status_code == 502:
@@ -119,6 +124,6 @@ class FirecrawlApp:
             return response
         return response

-    def _handle_error(self, response, action):
+    def _handle_error(self, response, action) -> None:
         error_message = response.json().get("error", "Unknown error occurred")
         raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
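The @@ markers above elide the body of the retry loop; a minimal standalone sketch of the pattern _post_request/_get_request implement (the exponential backoff_factor * 2**attempt delay is an assumption, since the sleep lines fall outside the shown hunks):

    import time
    import requests

    def post_with_retry(url, data, headers, retries=3, backoff_factor=0.5) -> requests.Response:
        response = requests.post(url, headers=headers, json=data)
        for attempt in range(retries):
            if response.status_code != 502:  # only 502 Bad Gateway is retried
                break
            time.sleep(backoff_factor * (2**attempt))  # 0.5s, 1s, 2s, ...
            response = requests.post(url, headers=headers, json=data)
        return response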

@@ -358,8 +358,7 @@ class NotionExtractor(BaseExtractor):

         if not data_source_binding:
             raise Exception(
-                f"No notion data source binding found for tenant {tenant_id} "
-                f"and notion workspace {notion_workspace_id}"
+                f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
             )

         return cast(str, data_source_binding.access_token)

@@ -112,7 +112,7 @@ class QAIndexProcessor(BaseIndexProcessor):
         df = pd.read_csv(file)
         text_docs = []
         for index, row in df.iterrows():
-            data = Document(page_content=row[0], metadata={"answer": row[1]})
+            data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
             text_docs.append(data)
         if len(text_docs) == 0:
             raise ValueError("The CSV file is empty.")
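Background on the .iloc change: indexing a Series with row[0] is label-based, and the positional fallback for non-integer labels is deprecated in recent pandas, while row.iloc[0] is always positional. A small illustrative sketch (the CSV columns are hypothetical):

    import pandas as pd
    from io import StringIO

    df = pd.read_csv(StringIO("question,answer\nWhat is Dify?,An LLM app platform\n"))
    for _, row in df.iterrows():
        question, answer = row.iloc[0], row.iloc[1]  # positional, independent of column labels
        print(question, "->", answer)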

@@ -127,7 +127,7 @@ class AIPPTGenerateToolAdapter:

         response = response.json()
         if response.get("code") != 0:
-            raise Exception(f'Failed to create task: {response.get("msg")}')
+            raise Exception(f"Failed to create task: {response.get('msg')}")

         return response.get("data", {}).get("id")

@@ -222,7 +222,7 @@ class AIPPTGenerateToolAdapter:
         elif model == "wenxin":
             response = response.json()
             if response.get("code") != 0:
-                raise Exception(f'Failed to generate content: {response.get("msg")}')
+                raise Exception(f"Failed to generate content: {response.get('msg')}")

         return response.get("data", "")

@@ -254,7 +254,7 @@ class AIPPTGenerateToolAdapter:

         response = response.json()
         if response.get("code") != 0:
-            raise Exception(f'Failed to generate ppt: {response.get("msg")}')
+            raise Exception(f"Failed to generate ppt: {response.get('msg')}")

         id = response.get("data", {}).get("id")
         cover_url = response.get("data", {}).get("cover_url")
@@ -270,7 +270,7 @@ class AIPPTGenerateToolAdapter:

         response = response.json()
         if response.get("code") != 0:
-            raise Exception(f'Failed to generate ppt: {response.get("msg")}')
+            raise Exception(f"Failed to generate ppt: {response.get('msg')}")

         export_code = response.get("data")
         if not export_code:
@@ -290,7 +290,7 @@ class AIPPTGenerateToolAdapter:

         response = response.json()
         if response.get("code") != 0:
-            raise Exception(f'Failed to generate ppt: {response.get("msg")}')
+            raise Exception(f"Failed to generate ppt: {response.get('msg')}")

         if response.get("msg") == "导出中":
             current_iteration += 1
@@ -343,7 +343,7 @@ class AIPPTGenerateToolAdapter:
             raise Exception(f"Failed to connect to aippt: {response.text}")
         response = response.json()
         if response.get("code") != 0:
-            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
+            raise Exception(f"Failed to connect to aippt: {response.get('msg')}")

         token = response.get("data", {}).get("token")
         expire = response.get("data", {}).get("time_expire")
@@ -379,7 +379,7 @@ class AIPPTGenerateToolAdapter:
             if cls._style_cache[key]["expire"] < now:
                 del cls._style_cache[key]

-        key = f'{credentials["aippt_access_key"]}#@#{user_id}'
+        key = f"{credentials['aippt_access_key']}#@#{user_id}"
         if key in cls._style_cache:
             return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]

@@ -396,11 +396,11 @@ class AIPPTGenerateToolAdapter:
         response = response.json()

         if response.get("code") != 0:
-            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
+            raise Exception(f"Failed to connect to aippt: {response.get('msg')}")

         colors = [
             {
-                "id": f'id-{item.get("id")}',
+                "id": f"id-{item.get('id')}",
                 "name": item.get("name"),
                 "en_name": item.get("en_name", item.get("name")),
             }
@@ -408,7 +408,7 @@ class AIPPTGenerateToolAdapter:
         ]
         styles = [
             {
-                "id": f'id-{item.get("id")}',
+                "id": f"id-{item.get('id')}",
                 "name": item.get("title"),
             }
             for item in response.get("data", {}).get("suit_style") or []
@@ -454,7 +454,7 @@ class AIPPTGenerateToolAdapter:
         response = response.json()

         if response.get("code") != 0:
-            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
+            raise Exception(f"Failed to connect to aippt: {response.get('msg')}")

         if len(response.get("data", {}).get("list") or []) > 0:
             return response.get("data", {}).get("list")[0].get("id")

api/core/tools/provider/builtin/aws/tools/bedrock_config.py (new file)
@@ -0,0 +1,114 @@
"""
Configuration classes for AWS Bedrock retrieve and generate API
"""

from dataclasses import dataclass
from typing import Any, Literal, Optional


@dataclass
class TextInferenceConfig:
    """Text inference configuration"""

    maxTokens: Optional[int] = None
    stopSequences: Optional[list[str]] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None


@dataclass
class PerformanceConfig:
    """Performance configuration"""

    latency: Literal["standard", "optimized"]


@dataclass
class PromptTemplate:
    """Prompt template configuration"""

    textPromptTemplate: str


@dataclass
class GuardrailConfig:
    """Guardrail configuration"""

    guardrailId: str
    guardrailVersion: str


@dataclass
class GenerationConfig:
    """Generation configuration"""

    additionalModelRequestFields: Optional[dict[str, Any]] = None
    guardrailConfiguration: Optional[GuardrailConfig] = None
    inferenceConfig: Optional[dict[str, TextInferenceConfig]] = None
    performanceConfig: Optional[PerformanceConfig] = None
    promptTemplate: Optional[PromptTemplate] = None


@dataclass
class VectorSearchConfig:
    """Vector search configuration"""

    filter: Optional[dict[str, Any]] = None
    numberOfResults: Optional[int] = None
    overrideSearchType: Optional[Literal["HYBRID", "SEMANTIC"]] = None


@dataclass
class RetrievalConfig:
    """Retrieval configuration"""

    vectorSearchConfiguration: VectorSearchConfig


@dataclass
class OrchestrationConfig:
    """Orchestration configuration"""

    additionalModelRequestFields: Optional[dict[str, Any]] = None
    inferenceConfig: Optional[dict[str, TextInferenceConfig]] = None
    performanceConfig: Optional[PerformanceConfig] = None
    promptTemplate: Optional[PromptTemplate] = None


@dataclass
class KnowledgeBaseConfig:
    """Knowledge base configuration"""

    generationConfiguration: GenerationConfig
    knowledgeBaseId: str
    modelArn: str
    orchestrationConfiguration: Optional[OrchestrationConfig] = None
    retrievalConfiguration: Optional[RetrievalConfig] = None


@dataclass
class SessionConfig:
    """Session configuration"""

    kmsKeyArn: Optional[str] = None
    sessionId: Optional[str] = None


@dataclass
class RetrieveAndGenerateConfiguration:
    """Retrieve and generate configuration
    The use of knowledgeBaseConfiguration or externalSourcesConfiguration depends on the type value
    """

    type: str = "KNOWLEDGE_BASE"
    knowledgeBaseConfiguration: Optional[KnowledgeBaseConfig] = None


@dataclass
class RetrieveAndGenerateConfig:
    """Retrieve and generate main configuration"""

    input: dict[str, str]
    retrieveAndGenerateConfiguration: RetrieveAndGenerateConfiguration
    sessionConfiguration: Optional[SessionConfig] = None
    sessionId: Optional[str] = None
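A hedged sketch of how these dataclasses compose into a request payload; dataclasses.asdict serializes the nesting, and a small helper drops the None optionals so they are omitted from the API call (the knowledge base id and model ARN are placeholders):

    from dataclasses import asdict

    config = RetrieveAndGenerateConfig(
        input={"text": "What is our refund policy?"},
        retrieveAndGenerateConfiguration=RetrieveAndGenerateConfiguration(
            knowledgeBaseConfiguration=KnowledgeBaseConfig(
                generationConfiguration=GenerationConfig(
                    inferenceConfig={"textInferenceConfig": TextInferenceConfig(maxTokens=2048)}
                ),
                knowledgeBaseId="KB123EXAMPLE",
                modelArn="arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
            )
        ),
    )

    def prune(value):
        # Recursively drop None fields so optional settings are omitted entirely.
        if isinstance(value, dict):
            return {k: prune(v) for k, v in value.items() if v is not None}
        return value

    payload = prune(asdict(config))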
@@ -0,0 +1,324 @@
import json
from typing import Any, Optional

import boto3

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class BedrockRetrieveAndGenerateTool(BuiltinTool):
    bedrock_client: Any = None

    def _create_text_inference_config(
        self,
        max_tokens: Optional[int] = None,
        stop_sequences: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> Optional[dict]:
        """Create text inference configuration"""
        if any([max_tokens, stop_sequences, temperature, top_p]):
            config = {}
            if max_tokens is not None:
                config["maxTokens"] = max_tokens
            if stop_sequences:
                try:
                    config["stopSequences"] = json.loads(stop_sequences)
                except json.JSONDecodeError:
                    config["stopSequences"] = []
            if temperature is not None:
                config["temperature"] = temperature
            if top_p is not None:
                config["topP"] = top_p
            return config
        return None

    def _create_guardrail_config(
        self,
        guardrail_id: Optional[str] = None,
        guardrail_version: Optional[str] = None,
    ) -> Optional[dict]:
        """Create guardrail configuration"""
        if guardrail_id and guardrail_version:
            return {"guardrailId": guardrail_id, "guardrailVersion": guardrail_version}
        return None

    def _create_generation_config(
        self,
        additional_model_fields: Optional[str] = None,
        guardrail_config: Optional[dict] = None,
        text_inference_config: Optional[dict] = None,
        performance_mode: Optional[str] = None,
        prompt_template: Optional[str] = None,
    ) -> dict:
        """Create generation configuration"""
        config = {}

        if additional_model_fields:
            try:
                config["additionalModelRequestFields"] = json.loads(additional_model_fields)
            except json.JSONDecodeError:
                pass

        if guardrail_config:
            config["guardrailConfiguration"] = guardrail_config

        if text_inference_config:
            config["inferenceConfig"] = {"textInferenceConfig": text_inference_config}

        if performance_mode:
            config["performanceConfig"] = {"latency": performance_mode}

        if prompt_template:
            config["promptTemplate"] = {"textPromptTemplate": prompt_template}

        return config

    def _create_orchestration_config(
        self,
        orchestration_additional_model_fields: Optional[str] = None,
        orchestration_text_inference_config: Optional[dict] = None,
        orchestration_performance_mode: Optional[str] = None,
        orchestration_prompt_template: Optional[str] = None,
    ) -> dict:
        """Create orchestration configuration"""
        config = {}

        if orchestration_additional_model_fields:
            try:
                config["additionalModelRequestFields"] = json.loads(orchestration_additional_model_fields)
            except json.JSONDecodeError:
                pass

        if orchestration_text_inference_config:
            config["inferenceConfig"] = {"textInferenceConfig": orchestration_text_inference_config}

        if orchestration_performance_mode:
            config["performanceConfig"] = {"latency": orchestration_performance_mode}

        if orchestration_prompt_template:
            config["promptTemplate"] = {"textPromptTemplate": orchestration_prompt_template}

        return config

    def _create_vector_search_config(
        self,
        number_of_results: int = 5,
        search_type: str = "SEMANTIC",
        metadata_filter: Optional[dict] = None,
    ) -> dict:
        """Create vector search configuration"""
        config = {
            "numberOfResults": number_of_results,
            "overrideSearchType": search_type,
        }

        # Only add filter if metadata_filter is not empty
        if metadata_filter:
            config["filter"] = metadata_filter

        return config

    def _bedrock_retrieve_and_generate(
        self,
        query: str,
        knowledge_base_id: str,
        model_arn: str,
        # Generation Configuration
        additional_model_fields: Optional[str] = None,
        guardrail_id: Optional[str] = None,
        guardrail_version: Optional[str] = None,
        max_tokens: Optional[int] = None,
        stop_sequences: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        performance_mode: str = "standard",
        prompt_template: Optional[str] = None,
        # Orchestration Configuration
        orchestration_additional_model_fields: Optional[str] = None,
        orchestration_max_tokens: Optional[int] = None,
        orchestration_stop_sequences: Optional[str] = None,
        orchestration_temperature: Optional[float] = None,
        orchestration_top_p: Optional[float] = None,
        orchestration_performance_mode: Optional[str] = None,
        orchestration_prompt_template: Optional[str] = None,
        # Retrieval Configuration
        number_of_results: int = 5,
        search_type: str = "SEMANTIC",
        metadata_filter: Optional[dict] = None,
        # Additional Configuration
        session_id: Optional[str] = None,
    ) -> dict[str, Any]:
        try:
            # Create text inference configurations
            text_inference_config = self._create_text_inference_config(max_tokens, stop_sequences, temperature, top_p)
            orchestration_text_inference_config = self._create_text_inference_config(
                orchestration_max_tokens, orchestration_stop_sequences, orchestration_temperature, orchestration_top_p
            )

            # Create guardrail configuration
            guardrail_config = self._create_guardrail_config(guardrail_id, guardrail_version)

            # Create vector search configuration
            vector_search_config = self._create_vector_search_config(number_of_results, search_type, metadata_filter)

            # Create generation configuration
            generation_config = self._create_generation_config(
                additional_model_fields, guardrail_config, text_inference_config, performance_mode, prompt_template
            )

            # Create orchestration configuration
            orchestration_config = self._create_orchestration_config(
                orchestration_additional_model_fields,
                orchestration_text_inference_config,
                orchestration_performance_mode,
                orchestration_prompt_template,
            )

            # Create knowledge base configuration
            knowledge_base_config = {
                "knowledgeBaseId": knowledge_base_id,
                "modelArn": model_arn,
                "generationConfiguration": generation_config,
                "orchestrationConfiguration": orchestration_config,
                "retrievalConfiguration": {"vectorSearchConfiguration": vector_search_config},
            }

            # Create request configuration
            request_config = {
                "input": {"text": query},
                "retrieveAndGenerateConfiguration": {
                    "type": "KNOWLEDGE_BASE",
                    "knowledgeBaseConfiguration": knowledge_base_config,
                },
            }

            # Add session configuration if provided
            if session_id and len(session_id) >= 2:
                request_config["sessionConfiguration"] = {"sessionId": session_id}
                request_config["sessionId"] = session_id

            # Send request
            response = self.bedrock_client.retrieve_and_generate(**request_config)

            # Process response
            result = {"output": response.get("output", {}).get("text", ""), "citations": []}

            # Process citations
            for citation in response.get("citations", []):
                citation_info = {
                    "text": citation.get("generatedResponsePart", {}).get("textResponsePart", {}).get("text", ""),
                    "references": [],
                }

                for ref in citation.get("retrievedReferences", []):
                    reference = {
                        "content": ref.get("content", {}).get("text", ""),
                        "metadata": ref.get("metadata", {}),
                        "location": None,
                    }

                    location = ref.get("location", {})
                    if location.get("type") == "S3":
                        reference["location"] = location.get("s3Location", {}).get("uri")

                    citation_info["references"].append(reference)

                result["citations"].append(citation_info)

            return result

        except Exception as e:
            raise Exception(f"Error calling Bedrock service: {str(e)}")

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> ToolInvokeMessage:
        try:
            # Initialize Bedrock client if not already initialized
            if not self.bedrock_client:
                aws_region = tool_parameters.get("aws_region")
                aws_access_key_id = tool_parameters.get("aws_access_key_id")
                aws_secret_access_key = tool_parameters.get("aws_secret_access_key")

                client_kwargs = {
                    "service_name": "bedrock-agent-runtime",
                }
                if aws_region:
                    client_kwargs["region_name"] = aws_region
                # Only add credentials if both access key and secret key are provided
                if aws_access_key_id and aws_secret_access_key:
                    client_kwargs.update(
                        {"aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key}
                    )

                try:
                    self.bedrock_client = boto3.client(**client_kwargs)
                except Exception as e:
                    return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")

            # Parse metadata filter if provided
            metadata_filter = None
            if metadata_filter_str := tool_parameters.get("metadata_filter"):
                try:
                    parsed_filter = json.loads(metadata_filter_str)
                    if parsed_filter:  # Only set if not empty
                        metadata_filter = parsed_filter
                except json.JSONDecodeError:
                    return self.create_text_message("metadata_filter must be a valid JSON string")

            try:
                response = self._bedrock_retrieve_and_generate(
                    query=tool_parameters["query"],
                    knowledge_base_id=tool_parameters["knowledge_base_id"],
                    model_arn=tool_parameters["model_arn"],
                    # Generation Configuration
                    additional_model_fields=tool_parameters.get("additional_model_fields"),
                    guardrail_id=tool_parameters.get("guardrail_id"),
                    guardrail_version=tool_parameters.get("guardrail_version"),
                    max_tokens=tool_parameters.get("max_tokens"),
                    stop_sequences=tool_parameters.get("stop_sequences"),
                    temperature=tool_parameters.get("temperature"),
                    top_p=tool_parameters.get("top_p"),
                    performance_mode=tool_parameters.get("performance_mode", "standard"),
                    prompt_template=tool_parameters.get("prompt_template"),
                    # Orchestration Configuration
                    orchestration_additional_model_fields=tool_parameters.get("orchestration_additional_model_fields"),
                    orchestration_max_tokens=tool_parameters.get("orchestration_max_tokens"),
                    orchestration_stop_sequences=tool_parameters.get("orchestration_stop_sequences"),
                    orchestration_temperature=tool_parameters.get("orchestration_temperature"),
                    orchestration_top_p=tool_parameters.get("orchestration_top_p"),
                    orchestration_performance_mode=tool_parameters.get("orchestration_performance_mode"),
                    orchestration_prompt_template=tool_parameters.get("orchestration_prompt_template"),
                    # Retrieval Configuration
                    number_of_results=tool_parameters.get("number_of_results", 5),
                    search_type=tool_parameters.get("search_type", "SEMANTIC"),
                    metadata_filter=metadata_filter,
                    # Additional Configuration
                    session_id=tool_parameters.get("session_id"),
                )
                return self.create_json_message(response)

            except Exception as e:
                return self.create_text_message(f"Tool invocation error: {str(e)}")

        except Exception as e:
            return self.create_text_message(f"Tool execution error: {str(e)}")

    def validate_parameters(self, parameters: dict[str, Any]) -> None:
        """Validate the parameters"""
        required_params = ["query", "model_arn", "knowledge_base_id"]
        for param in required_params:
            if not parameters.get(param):
                raise ValueError(f"{param} is required")

        # Validate metadata filter if provided
        if metadata_filter_str := parameters.get("metadata_filter"):
            try:
                if not isinstance(json.loads(metadata_filter_str), dict):
                    raise ValueError("metadata_filter must be a valid JSON object")
            except json.JSONDecodeError:
                raise ValueError("metadata_filter must be a valid JSON string")
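For orientation, the equivalent direct call to the AWS API this tool wraps; retrieve_and_generate is the bedrock-agent-runtime operation the tool builds its request_config for, while the region, knowledge base id, and model ARN below are placeholders:

    import boto3

    client = boto3.client("bedrock-agent-runtime", region_name="us-east-1")
    response = client.retrieve_and_generate(
        input={"text": "What is our refund policy?"},
        retrieveAndGenerateConfiguration={
            "type": "KNOWLEDGE_BASE",
            "knowledgeBaseConfiguration": {
                "knowledgeBaseId": "KB123EXAMPLE",
                "modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
                "retrievalConfiguration": {"vectorSearchConfiguration": {"numberOfResults": 5}},
            },
        },
    )
    print(response["output"]["text"])
    # The response also carries a sessionId; feeding it back (the tool's session_id
    # parameter) continues the conversation server-side.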
@@ -0,0 +1,358 @@
identity:
  name: bedrock_retrieve_and_generate
  author: AWS
  label:
    en_US: Bedrock Retrieve and Generate
    zh_Hans: Bedrock检索和生成
  icon: icon.svg

description:
  human:
    en_US: A tool for retrieving and generating information using Amazon Bedrock Knowledge Base
    zh_Hans: 使用Amazon Bedrock知识库进行信息检索和生成的工具
  llm: A tool for retrieving and generating information using Amazon Bedrock Knowledge Base

parameters:
  # Additional Configuration
  - name: session_id
    type: string
    required: false
    label:
      en_US: Session ID
      zh_Hans: 会话ID
    human_description:
      en_US: Optional session ID for continuous conversations
      zh_Hans: 用于连续对话的可选会话ID
    form: form

  # AWS Configuration
  - name: aws_region
    type: string
    required: false
    label:
      en_US: AWS Region
      zh_Hans: AWS区域
    human_description:
      en_US: AWS region for the Bedrock service
      zh_Hans: Bedrock服务的AWS区域
    form: form

  - name: aws_access_key_id
    type: string
    required: false
    label:
      en_US: AWS Access Key ID
      zh_Hans: AWS访问密钥ID
    human_description:
      en_US: AWS access key ID for authentication (optional)
      zh_Hans: 用于身份验证的AWS访问密钥ID(可选)
    form: form

  - name: aws_secret_access_key
    type: string
    required: false
    label:
      en_US: AWS Secret Access Key
      zh_Hans: AWS秘密访问密钥
    human_description:
      en_US: AWS secret access key for authentication (optional)
      zh_Hans: 用于身份验证的AWS秘密访问密钥(可选)
    form: form

  # Knowledge Base Configuration
  - name: knowledge_base_id
    type: string
    required: true
    label:
      en_US: Knowledge Base ID
      zh_Hans: 知识库ID
    human_description:
      en_US: ID of the Bedrock Knowledge Base
      zh_Hans: Bedrock知识库的ID
    form: form

  - name: model_arn
    type: string
    required: true
    label:
      en_US: Model ARN
      zh_Hans: 模型ARN
    human_description:
      en_US: The ARN of the model to use
      zh_Hans: 要使用的模型ARN
    form: form

  # Retrieval Configuration
  - name: query
    type: string
    required: true
    label:
      en_US: Query
      zh_Hans: 查询
    human_description:
      en_US: The search query to retrieve information
      zh_Hans: 用于检索信息的查询语句
    form: llm

  - name: number_of_results
    type: number
    required: false
    label:
      en_US: Number of Results
      zh_Hans: 结果数量
    human_description:
      en_US: Number of results to retrieve (1-10)
      zh_Hans: 要检索的结果数量(1-10)
    default: 5
    min: 1
    max: 10
    form: form

  - name: search_type
    type: select
    required: false
    label:
      en_US: Search Type
      zh_Hans: 搜索类型
    human_description:
      en_US: Type of search to perform
      zh_Hans: 要执行的搜索类型
    default: SEMANTIC
    options:
      - value: SEMANTIC
        label:
          en_US: Semantic Search
          zh_Hans: 语义搜索
      - value: HYBRID
        label:
          en_US: Hybrid Search
          zh_Hans: 混合搜索
    form: form

  - name: metadata_filter
    type: string
    required: false
    label:
      en_US: Metadata Filter
      zh_Hans: 元数据过滤器
    human_description:
      en_US: JSON formatted filter conditions for metadata, supporting operations like equals, greaterThan, lessThan, etc.
      zh_Hans: 元数据的JSON格式过滤条件,支持等于、大于、小于等操作
    default: "{}"
    form: form

  # Generation Configuration
  - name: guardrail_id
    type: string
    required: false
    label:
      en_US: Guardrail ID
      zh_Hans: 防护栏ID
    human_description:
      en_US: ID of the guardrail to apply
      zh_Hans: 要应用的防护栏ID
    form: form

  - name: guardrail_version
    type: string
    required: false
    label:
      en_US: Guardrail Version
      zh_Hans: 防护栏版本
    human_description:
      en_US: Version of the guardrail to apply
      zh_Hans: 要应用的防护栏版本
    form: form

  - name: max_tokens
    type: number
    required: false
    label:
      en_US: Maximum Tokens
      zh_Hans: 最大令牌数
    human_description:
      en_US: Maximum number of tokens to generate
      zh_Hans: 生成的最大令牌数
    default: 2048
    form: form

  - name: stop_sequences
    type: string
    required: false
    label:
      en_US: Stop Sequences
      zh_Hans: 停止序列
    human_description:
      en_US: JSON array of strings that will stop generation when encountered
      zh_Hans: JSON数组格式的字符串,遇到这些序列时将停止生成
    default: "[]"
    form: form

  - name: temperature
    type: number
    required: false
    label:
      en_US: Temperature
      zh_Hans: 温度
    human_description:
      en_US: Controls randomness in the output (0-1)
      zh_Hans: 控制输出的随机性(0-1)
    default: 0.7
    min: 0
    max: 1
    form: form

  - name: top_p
    type: number
    required: false
    label:
      en_US: Top P
      zh_Hans: Top P值
    human_description:
      en_US: Controls diversity via nucleus sampling (0-1)
      zh_Hans: 通过核采样控制多样性(0-1)
    default: 0.95
    min: 0
    max: 1
    form: form

  - name: performance_mode
    type: select
    required: false
    label:
      en_US: Performance Mode
      zh_Hans: 性能模式
    human_description:
      en_US: Select performance optimization mode (performanceConfig.latency)
      zh_Hans: 选择性能优化模式(performanceConfig.latency)
    default: standard
    options:
      - value: standard
        label:
          en_US: Standard
          zh_Hans: 标准
      - value: optimized
        label:
          en_US: Optimized
          zh_Hans: 优化
    form: form

  - name: prompt_template
    type: string
    required: false
    label:
      en_US: Prompt Template
      zh_Hans: 提示模板
    human_description:
      en_US: Custom prompt template for generation
      zh_Hans: 用于生成的自定义提示模板
    form: form

  - name: additional_model_fields
    type: string
    required: false
    label:
      en_US: Additional Model Fields
      zh_Hans: 额外模型字段
    human_description:
      en_US: JSON formatted additional fields for model configuration
      zh_Hans: JSON格式的额外模型配置字段
    default: "{}"
    form: form

  # Orchestration Configuration
  - name: orchestration_max_tokens
    type: number
    required: false
    label:
      en_US: Orchestration Maximum Tokens
      zh_Hans: 编排最大令牌数
    human_description:
      en_US: Maximum number of tokens for orchestration
      zh_Hans: 编排过程的最大令牌数
    default: 2048
    form: form

  - name: orchestration_stop_sequences
    type: string
    required: false
    label:
      en_US: Orchestration Stop Sequences
      zh_Hans: 编排停止序列
    human_description:
      en_US: JSON array of strings that will stop orchestration when encountered
      zh_Hans: JSON数组格式的字符串,遇到这些序列时将停止编排
    default: "[]"
    form: form

  - name: orchestration_temperature
    type: number
    required: false
    label:
      en_US: Orchestration Temperature
      zh_Hans: 编排温度
    human_description:
      en_US: Controls randomness in the orchestration output (0-1)
      zh_Hans: 控制编排输出的随机性(0-1)
    default: 0.7
    min: 0
    max: 1
    form: form

  - name: orchestration_top_p
    type: number
    required: false
    label:
      en_US: Orchestration Top P
      zh_Hans: 编排Top P值
    human_description:
      en_US: Controls diversity via nucleus sampling in orchestration (0-1)
      zh_Hans: 通过核采样控制编排的多样性(0-1)
    default: 0.95
    min: 0
    max: 1
    form: form

  - name: orchestration_performance_mode
    type: select
    required: false
    label:
      en_US: Orchestration Performance Mode
      zh_Hans: 编排性能模式
    human_description:
      en_US: Select performance optimization mode for orchestration
      zh_Hans: 选择编排的性能优化模式
    default: standard
    options:
      - value: standard
        label:
          en_US: Standard
          zh_Hans: 标准
      - value: optimized
        label:
          en_US: Optimized
          zh_Hans: 优化
    form: form

  - name: orchestration_prompt_template
    type: string
    required: false
    label:
      en_US: Orchestration Prompt Template
      zh_Hans: 编排提示模板
    human_description:
      en_US: Custom prompt template for orchestration
      zh_Hans: 用于编排的自定义提示模板
    form: form

  - name: orchestration_additional_model_fields
    type: string
    required: false
    label:
      en_US: Orchestration Additional Model Fields
      zh_Hans: 编排额外模型字段
    human_description:
      en_US: JSON formatted additional fields for orchestration model configuration
      zh_Hans: JSON格式的编排模型额外配置字段
    default: "{}"
    form: form
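The metadata_filter string above is parsed with json.loads by the tool and placed under retrievalConfiguration.vectorSearchConfiguration.filter, so it must encode Bedrock's retrieval-filter operators (equals, greaterThan, andAll, ...). A hypothetical filter, with illustrative field names:

    import json

    metadata_filter = json.dumps(
        {
            "andAll": [
                {"equals": {"key": "department", "value": "finance"}},
                {"greaterThan": {"key": "year", "value": 2023}},
            ]
        }
    )
    # Passed as the tool's metadata_filter parameter; the default "{}" means no filtering.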
@@ -229,8 +229,7 @@ class NovaReelTool(BuiltinTool):

         if async_mode:
             return self.create_text_message(
-                f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
-                f"Video will be available at: {video_uri}"
+                f"Video generation started.\nInvocation ARN: {invocation_arn}\nVideo will be available at: {video_uri}"
             )

         return self._wait_for_completion(bedrock, s3_client, invocation_arn)

@@ -65,7 +65,7 @@ class BaiduFieldTranslateTool(BuiltinTool, BaiduTranslateToolBase):
             if "trans_result" in result:
                 result_text = result["trans_result"][0]["dst"]
             else:
-                result_text = f'{result["error_code"]}: {result["error_msg"]}'
+                result_text = f"{result['error_code']}: {result['error_msg']}"

             return self.create_text_message(str(result_text))
         except requests.RequestException as e:

@@ -52,7 +52,7 @@ class BaiduLanguageTool(BuiltinTool, BaiduTranslateToolBase):

             result_text = ""
             if result["error_code"] != 0:
-                result_text = f'{result["error_code"]}: {result["error_msg"]}'
+                result_text = f"{result['error_code']}: {result['error_msg']}"
             else:
                 result_text = result["data"]["src"]
             result_text = self.mapping_result(description_language, result_text)

@@ -58,7 +58,7 @@ class BaiduTranslateTool(BuiltinTool, BaiduTranslateToolBase):
             if "trans_result" in result:
                 result_text = result["trans_result"][0]["dst"]
             else:
-                result_text = f'{result["error_code"]}: {result["error_msg"]}'
+                result_text = f"{result['error_code']}: {result['error_msg']}"

             return self.create_text_message(str(result_text))
         except requests.RequestException as e:
Some files were not shown because too many files have changed in this diff.