mirror of
https://github.com/langgenius/dify.git
synced 2026-02-09 09:44:00 +00:00
Compare commits
95 Commits
feat/retry
...
provider-g
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df82bccfdf | ||
|
|
5ba875f96b | ||
|
|
409cc7d9b0 | ||
|
|
fe26be2312 | ||
|
|
34519de3b7 | ||
|
|
147d578922 | ||
|
|
9c317b64c3 | ||
|
|
3b8f6233b0 | ||
|
|
455b0cd696 | ||
|
|
1fa66405c5 | ||
|
|
b680a85b57 | ||
|
|
682ebc5f64 | ||
|
|
b8ba39dfae | ||
|
|
6c9e6a3a5a | ||
|
|
70698024f5 | ||
|
|
6df17a334c | ||
|
|
a5fb59b17f | ||
|
|
7ed6485f86 | ||
|
|
478150e850 | ||
|
|
3c2e30f348 | ||
|
|
b873e6349c | ||
|
|
2b1a32fd9c | ||
|
|
a2105634a4 | ||
|
|
7c71bd7be7 | ||
|
|
7c1961e618 | ||
|
|
baeddd4d15 | ||
|
|
6f5a8a33d9 | ||
|
|
52b2559a14 | ||
|
|
3d150c30a7 | ||
|
|
e58e573f3e | ||
|
|
375aa38f5d | ||
|
|
0e6317678f | ||
|
|
e7dffcd0f6 | ||
|
|
065304d175 | ||
|
|
15f43dd326 | ||
|
|
09d759d196 | ||
|
|
68757950ce | ||
|
|
3c45bdf18a | ||
|
|
c135967e59 | ||
|
|
f71af7c2a8 | ||
|
|
5b01eb9437 | ||
|
|
2e716f80d2 | ||
|
|
d7c0bc8c23 | ||
|
|
f30bf08580 | ||
|
|
a640803fc9 | ||
|
|
9954ddb780 | ||
|
|
b218df6920 | ||
|
|
5b6950e545 | ||
|
|
c7911c7130 | ||
|
|
62f792ea14 | ||
|
|
6a85960605 | ||
|
|
63a0b8ba79 | ||
|
|
634b382a3d | ||
|
|
fbf5deda21 | ||
|
|
d4b848272e | ||
|
|
fc29f2003e | ||
|
|
ab469aa07d | ||
|
|
562450751f | ||
|
|
adacd01f82 | ||
|
|
74d3320519 | ||
|
|
309a15d1ba | ||
|
|
bcef11681d | ||
|
|
8d15c8cfbf | ||
|
|
716bb8574d | ||
|
|
bd2fec4813 | ||
|
|
ead4b34127 | ||
|
|
72ae414da4 | ||
|
|
4c9618be3f | ||
|
|
901028f1e8 | ||
|
|
adfbfc1255 | ||
|
|
b66c03dfe9 | ||
|
|
2a909e634b | ||
|
|
9d86056f1c | ||
|
|
309fd76ddf | ||
|
|
a3293b154e | ||
|
|
eb8963a673 | ||
|
|
89ce9a5db2 | ||
|
|
f4f2567105 | ||
|
|
5a3fe61f2a | ||
|
|
55c327ffcb | ||
|
|
0fdb39f1c3 | ||
|
|
dae1b5a619 | ||
|
|
26b5680913 | ||
|
|
a2855fa24a | ||
|
|
9c3cf7b69a | ||
|
|
be7877f526 | ||
|
|
e765d8e69e | ||
|
|
4bd8df1fd3 | ||
|
|
4e76f2fc44 | ||
|
|
cf00ee42f5 | ||
|
|
886758d2be | ||
|
|
8339d2c7c9 | ||
|
|
811e4bd0cf | ||
|
|
49feff082f | ||
|
|
efdd54a670 |
@@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
# Refresh token expiration time in days
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
|
||||
|
||||
@@ -85,11 +85,11 @@ ignore = [
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"F401", # unused-import
|
||||
]
|
||||
|
||||
[lint.pyflakes]
|
||||
extend-generics = [
|
||||
allowed-unused-imports = [
|
||||
"_pytest.monkeypatch",
|
||||
"tests.integration_tests",
|
||||
"tests.unit_tests",
|
||||
]
|
||||
|
||||
@@ -55,7 +55,7 @@ RUN apt-get update \
|
||||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-8 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
|
||||
&& apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.19+dfsg-1 perl=5.40.0-8 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
|
||||
# install a chinese font to support the use of tools like matplotlib
|
||||
&& apt-get install -y fonts-noto-cjk \
|
||||
&& apt-get autoremove -y \
|
||||
|
||||
29
api/app.py
29
api/app.py
@@ -1,12 +1,8 @@
|
||||
from libs import version_utils
|
||||
|
||||
# preparation before creating app
|
||||
version_utils.check_supported_python_version()
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def is_db_command():
|
||||
import sys
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
|
||||
return True
|
||||
return False
|
||||
@@ -18,10 +14,25 @@ if is_db_command():
|
||||
|
||||
app = create_migrations_app()
|
||||
else:
|
||||
from app_factory import create_app
|
||||
from libs import threadings_utils
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
from gevent import monkey # type: ignore
|
||||
|
||||
threadings_utils.apply_gevent_threading_patch()
|
||||
# gevent
|
||||
monkey.patch_all()
|
||||
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
|
||||
import psycogreen.gevent # type: ignore
|
||||
|
||||
psycogreen.gevent.patch_psycopg()
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
|
||||
@@ -488,6 +488,11 @@ class AuthConfig(BaseSettings):
|
||||
default=60,
|
||||
)
|
||||
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
|
||||
description="Expiration time for refresh tokens in days",
|
||||
default=30,
|
||||
)
|
||||
|
||||
LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
|
||||
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
|
||||
default=86400,
|
||||
@@ -601,7 +606,7 @@ class RagEtlConfig(BaseSettings):
|
||||
|
||||
UNSTRUCTURED_API_KEY: Optional[str] = Field(
|
||||
description="API key for Unstructured.io service",
|
||||
default=None,
|
||||
default="",
|
||||
)
|
||||
|
||||
SCARF_NO_ANALYTICS: Optional[str] = Field(
|
||||
@@ -667,6 +672,11 @@ class IndexingConfig(BaseSettings):
|
||||
default=4000,
|
||||
)
|
||||
|
||||
CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
|
||||
description="Maximum number of child chunks to preview",
|
||||
default=50,
|
||||
)
|
||||
|
||||
|
||||
class MultiModalTransferConfig(BaseSettings):
|
||||
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
|
||||
@@ -765,6 +775,13 @@ class LoginConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class AccountConfig(BaseSettings):
|
||||
ACCOUNT_DELETION_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
|
||||
description="Duration in minutes for which a account deletion token remains valid",
|
||||
default=5,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@@ -792,6 +809,7 @@ class FeatureConfig(
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
# hosted services config
|
||||
HostedServiceConfig,
|
||||
CeleryBeatConfig,
|
||||
|
||||
@@ -57,12 +57,13 @@ class AppListApi(Resource):
|
||||
)
|
||||
parser.add_argument("name", type=str, location="args", required=False)
|
||||
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
AppInvokeQuotaExceededError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
@@ -76,7 +75,7 @@ class CompletionMessageApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
@@ -141,7 +140,7 @@ class ChatMessageApi(Resource):
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
||||
@@ -273,8 +273,7 @@ FROM
|
||||
messages m
|
||||
ON c.id = m.conversation_id
|
||||
WHERE
|
||||
c.override_model_configs IS NULL
|
||||
AND c.app_id = :app_id"""
|
||||
c.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@@ -14,7 +14,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from factories import variable_factory
|
||||
from fields.workflow_fields import workflow_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
@@ -440,29 +440,29 @@ class WorkflowConfigApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
class DraftWorkflowNodeRetriableApi(Resource):
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_fields)
|
||||
def post(self, app_model: App, node_id: str):
|
||||
@marshal_with(workflow_pagination_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Run draft workflow node
|
||||
Get published workflows
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
page = args.get("page")
|
||||
limit = args.get("limit")
|
||||
workflow_service = WorkflowService()
|
||||
workflow_node_execution = workflow_service.run_retriable_draft_workflow_node(
|
||||
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs", {}), account=current_user
|
||||
)
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit)
|
||||
|
||||
return workflow_node_execution
|
||||
return {"items": workflows, "page": page, "limit": limit, "has_more": has_more}
|
||||
|
||||
|
||||
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
|
||||
@@ -479,9 +479,9 @@ api.add_resource(
|
||||
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
|
||||
)
|
||||
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
|
||||
api.add_resource(PublishedAllWorkflowApi, "/apps/<uuid:app_id>/workflows")
|
||||
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
|
||||
)
|
||||
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
|
||||
api.add_resource(DraftWorkflowNodeRetriableApi, "/apps/<uuid:app_id>/workflows/draft/retry/nodes/<string:node_id>/run")
|
||||
|
||||
@@ -53,3 +53,9 @@ class EmailCodeLoginRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "email_code_login_rate_limit_exceeded"
|
||||
description = "Too many login emails have been sent. Please try again in 5 minutes."
|
||||
code = 429
|
||||
|
||||
|
||||
class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "email_code_account_deletion_rate_limit_exceeded"
|
||||
description = "Too many account deletion emails have been sent. Please try again in 5 minutes."
|
||||
code = 429
|
||||
|
||||
@@ -6,13 +6,8 @@ from flask_restful import Resource, reqparse # type: ignore
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError
|
||||
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.wraps import setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
@@ -20,6 +15,7 @@ from libs.helper import email, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@@ -129,6 +125,8 @@ class ForgotPasswordResetApi(Resource):
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
pass
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from flask import request
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
@@ -16,6 +17,7 @@ from controllers.console.auth.error import (
|
||||
)
|
||||
from controllers.console.error import (
|
||||
AccountBannedError,
|
||||
AccountInFreezeError,
|
||||
AccountNotFound,
|
||||
EmailSendIpLimitError,
|
||||
NotAllowedCreateWorkspace,
|
||||
@@ -26,6 +28,8 @@ from libs.helper import email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@@ -44,6 +48,9 @@ class LoginApi(Resource):
|
||||
parser.add_argument("language", type=str, required=False, default="en-US", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
|
||||
if is_login_error_rate_limit:
|
||||
raise EmailPasswordLoginLimitError()
|
||||
@@ -113,8 +120,10 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
@@ -142,8 +151,11 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
|
||||
@@ -177,7 +189,10 @@ class EmailCodeLoginApi(Resource):
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.revoke_email_code_login_token(args["token"])
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
tenant = TenantService.get_join_tenants(account)
|
||||
if not tenant:
|
||||
@@ -196,6 +211,8 @@ class EmailCodeLoginApi(Resource):
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
return NotAllowedCreateWorkspace()
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
@@ -16,7 +16,7 @@ from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountNotFoundError
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@@ -99,6 +99,8 @@ class OAuthCallback(Resource):
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin"
|
||||
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||
)
|
||||
except AccountRegisterError as e:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
|
||||
|
||||
# Check account status
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
|
||||
@@ -92,3 +92,12 @@ class UnauthorizedAndForceLogout(BaseHTTPException):
|
||||
error_code = "unauthorized_and_force_logout"
|
||||
description = "Unauthorized and force logout."
|
||||
code = 401
|
||||
|
||||
|
||||
class AccountInFreezeError(BaseHTTPException):
|
||||
error_code = "account_in_freeze"
|
||||
code = 400
|
||||
description = (
|
||||
"This email account has been deleted within the past 30 days"
|
||||
"and is temporarily unavailable for new account registration."
|
||||
)
|
||||
|
||||
@@ -18,7 +18,11 @@ from controllers.console.explore.error import NotChatAppError, NotCompletionAppE
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
|
||||
@@ -66,10 +66,17 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
parser.add_argument("content", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(app_model, message_id, current_user, args.get("rating"), args.get("content"))
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=current_user,
|
||||
rating=args.get("rating"),
|
||||
content=args.get("content"),
|
||||
)
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
|
||||
@@ -13,7 +13,11 @@ from controllers.console.explore.error import NotWorkflowAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
|
||||
@@ -11,6 +11,7 @@ from controllers.console import api
|
||||
from controllers.console.workspace.error import (
|
||||
AccountAlreadyInitedError,
|
||||
CurrentPasswordIncorrectError,
|
||||
InvalidAccountDeletionCodeError,
|
||||
InvalidInvitationCodeError,
|
||||
RepeatPasswordNotMatchError,
|
||||
)
|
||||
@@ -21,6 +22,7 @@ from libs.helper import TimestampField, timezone
|
||||
from libs.login import login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
|
||||
@@ -242,6 +244,54 @@ class AccountIntegrateApi(Resource):
|
||||
return {"data": integrate_data}
|
||||
|
||||
|
||||
class AccountDeleteVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
account = current_user
|
||||
|
||||
token, code = AccountService.generate_account_deletion_verification_code(account)
|
||||
AccountService.send_account_deletion_verification_email(account, code)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
class AccountDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
|
||||
raise InvalidAccountDeletionCodeError()
|
||||
|
||||
AccountService.delete_account(account)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class AccountDeleteUpdateFeedbackApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("feedback", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
# Register API resources
|
||||
api.add_resource(AccountInitApi, "/account/init")
|
||||
api.add_resource(AccountProfileApi, "/account/profile")
|
||||
@@ -252,5 +302,8 @@ api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
|
||||
api.add_resource(AccountTimezoneApi, "/account/timezone")
|
||||
api.add_resource(AccountPasswordApi, "/account/password")
|
||||
api.add_resource(AccountIntegrateApi, "/account/integrates")
|
||||
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
|
||||
api.add_resource(AccountDeleteApi, "/account/delete")
|
||||
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
|
||||
# api.add_resource(AccountEmailApi, '/account/email')
|
||||
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
|
||||
|
||||
@@ -35,3 +35,9 @@ class AccountNotInitializedError(BaseHTTPException):
|
||||
error_code = "account_not_initialized"
|
||||
description = "The account has not been initialized yet. Please proceed with the initialization process first."
|
||||
code = 400
|
||||
|
||||
|
||||
class InvalidAccountDeletionCodeError(BaseHTTPException):
|
||||
error_code = "invalid_account_deletion_code"
|
||||
description = "Invalid account deletion code."
|
||||
code = 400
|
||||
|
||||
@@ -122,7 +122,7 @@ class MemberUpdateRoleApi(Resource):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
|
||||
member = db.session.get(Account, str(member_id))
|
||||
if member:
|
||||
if not member:
|
||||
abort(404)
|
||||
|
||||
try:
|
||||
|
||||
@@ -18,7 +18,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
AppInvokeQuotaExceededError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
@@ -74,7 +73,7 @@ class CompletionApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
@@ -133,7 +132,7 @@ class ChatApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
||||
@@ -108,7 +108,13 @@ class MessageFeedbackApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(app_model, message_id, end_user, args.get("rating"), args.get("content"))
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=args.get("rating"),
|
||||
content=args.get("content"),
|
||||
)
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
AppInvokeQuotaExceededError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
@@ -94,7 +93,7 @@ class WorkflowRunApi(Resource):
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
|
||||
@@ -8,12 +8,16 @@ from werkzeug.exceptions import NotFound
|
||||
import services.dataset_service
|
||||
from controllers.common.errors import FilenameNotExistsError
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.app.error import (
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
ProviderNotInitializeError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.service_api.dataset.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentIndexingError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
)
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
@@ -186,7 +190,10 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
@@ -238,14 +245,22 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
|
||||
@@ -19,7 +19,11 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
|
||||
@@ -14,7 +14,11 @@ from controllers.web.error import (
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
|
||||
@@ -339,13 +339,13 @@ class BaseAgentRunner(AppRunner):
|
||||
raise ValueError(f"Agent thought {agent_thought.id} not found")
|
||||
agent_thought = queried_thought
|
||||
|
||||
if thought is not None:
|
||||
if thought:
|
||||
agent_thought.thought = thought
|
||||
|
||||
if tool_name is not None:
|
||||
if tool_name:
|
||||
agent_thought.tool = tool_name
|
||||
|
||||
if tool_input is not None:
|
||||
if tool_input:
|
||||
if isinstance(tool_input, dict):
|
||||
try:
|
||||
tool_input = json.dumps(tool_input, ensure_ascii=False)
|
||||
@@ -354,7 +354,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
agent_thought.tool_input = tool_input
|
||||
|
||||
if observation is not None:
|
||||
if observation:
|
||||
if isinstance(observation, dict):
|
||||
try:
|
||||
observation = json.dumps(observation, ensure_ascii=False)
|
||||
@@ -363,7 +363,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
agent_thought.observation = observation
|
||||
|
||||
if answer is not None:
|
||||
if answer:
|
||||
agent_thought.answer = answer
|
||||
|
||||
if messages_ids is not None and len(messages_ids) > 0:
|
||||
|
||||
@@ -21,7 +21,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from extensions.ext_database import db
|
||||
@@ -336,7 +336,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@@ -67,24 +67,17 @@ from models.account import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
|
||||
class AdvancedChatAppGenerateTaskPipeline:
|
||||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
_conversation_name_generate_thread: Optional[Thread] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
@@ -96,7 +89,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
stream: bool,
|
||||
dialogue_count: int,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream,
|
||||
@@ -113,32 +106,35 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
else:
|
||||
raise NotImplementedError(f"User type not supported: {type(user)}")
|
||||
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
|
||||
self._conversation_id = conversation.id
|
||||
self._conversation_mode = conversation.mode
|
||||
|
||||
self._message_id = message.id
|
||||
self._message_created_at = int(message.created_at.timestamp())
|
||||
|
||||
self._workflow_system_variables = {
|
||||
SystemVariableKey.QUERY: message.query,
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
}
|
||||
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables={
|
||||
SystemVariableKey.QUERY: message.query,
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation.id,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
self._message_cycle_manager = MessageCycleManage(
|
||||
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||
)
|
||||
|
||||
self._conversation_name_generate_thread = None
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._conversation_id = conversation.id
|
||||
self._conversation_mode = conversation.mode
|
||||
self._message_id = message.id
|
||||
self._message_created_at = int(message.created_at.timestamp())
|
||||
self._conversation_name_generate_thread: Thread | None = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
self._workflow_run_id = ""
|
||||
self._workflow_run_id: str = ""
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@@ -146,13 +142,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
:return:
|
||||
"""
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
|
||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||
)
|
||||
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
|
||||
if self._stream:
|
||||
if self._base_task_pipeline._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
@@ -269,24 +265,26 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
# init fake graph runtime state
|
||||
graph_runtime_state: Optional[GraphRuntimeState] = None
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
yield self._base_task_pipeline._ping_stream_response()
|
||||
elif isinstance(event, QueueErrorEvent):
|
||||
with Session(db.engine) as session:
|
||||
err = self._handle_error(event=event, session=session, message_id=self._message_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
err = self._base_task_pipeline._handle_error(
|
||||
event=event, session=session, message_id=self._message_id
|
||||
)
|
||||
session.commit()
|
||||
yield self._error_to_stream_response(err)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_run = self._handle_workflow_run_start(
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
user_id=self._user_id,
|
||||
@@ -297,7 +295,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {self._message_id}")
|
||||
message.workflow_run_id = workflow_run.id
|
||||
workflow_start_resp = self._workflow_start_to_stream_response(
|
||||
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
@@ -310,12 +308,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
node_retry_resp = self._workflow_node_retry_to_stream_response(
|
||||
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -329,13 +329,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
|
||||
node_start_resp = self._workflow_node_start_to_stream_response(
|
||||
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -348,12 +350,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
# Record files if it's an answer node or end node
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
|
||||
self._recorded_files.extend(
|
||||
self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||
session=session, event=event
|
||||
)
|
||||
|
||||
node_finish_resp = self._workflow_node_finish_to_stream_response(
|
||||
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -364,10 +370,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||
session=session, event=event
|
||||
)
|
||||
|
||||
node_finish_resp = self._workflow_node_finish_to_stream_response(
|
||||
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -381,13 +389,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_start_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_start_resp
|
||||
@@ -395,13 +407,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_finish_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_finish_resp
|
||||
@@ -409,9 +425,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -423,9 +441,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -437,9 +457,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -454,8 +476,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -466,21 +488,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield workflow_finish_resp
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
self._base_task_pipeline._queue_manager.publish(
|
||||
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -491,21 +515,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield workflow_finish_resp
|
||||
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
self._base_task_pipeline._queue_manager.publish(
|
||||
QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -517,20 +543,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
trace_manager=trace_manager,
|
||||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
|
||||
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
|
||||
err = self._base_task_pipeline._handle_error(
|
||||
event=err_event, session=session, message_id=self._message_id
|
||||
)
|
||||
session.commit()
|
||||
|
||||
yield workflow_finish_resp
|
||||
yield self._error_to_stream_response(err)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent):
|
||||
if self._workflow_run_id and graph_runtime_state:
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -541,7 +569,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -555,18 +583,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
yield self._message_end_to_stream_response()
|
||||
break
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._handle_retriever_resources(event)
|
||||
self._message_cycle_manager._handle_retriever_resources(event)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
session.commit()
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
self._handle_annotation_reply(event)
|
||||
self._message_cycle_manager._handle_annotation_reply(event)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message = self._get_message(session=session)
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
@@ -587,23 +615,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
tts_publisher.publish(queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._message_to_stream_response(
|
||||
yield self._message_cycle_manager._message_to_stream_response(
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
# published by moderation
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||
output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
|
||||
self._task_state.answer
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
yield self._message_cycle_manager._message_replace_to_stream_response(
|
||||
answer=output_moderation_answer
|
||||
)
|
||||
|
||||
# Save message
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
|
||||
session.commit()
|
||||
|
||||
@@ -621,7 +653,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||
message = self._get_message(session=session)
|
||||
message.answer = self._task_state.answer
|
||||
message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||
message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
@@ -685,20 +717,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
:param text: text
|
||||
:return: True if output moderation should direct output, otherwise False
|
||||
"""
|
||||
if self._output_moderation_handler:
|
||||
if self._output_moderation_handler.should_direct_output():
|
||||
if self._base_task_pipeline._output_moderation_handler:
|
||||
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.answer = self._output_moderation_handler.get_final_output()
|
||||
self._queue_manager.publish(
|
||||
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
|
||||
self._base_task_pipeline._queue_manager.publish(
|
||||
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
self._base_task_pipeline._queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
return True
|
||||
else:
|
||||
self._output_moderation_handler.append_new_token(text)
|
||||
self._base_task_pipeline._output_moderation_handler.append_new_token(text)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@@ -245,7 +245,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@@ -18,7 +18,7 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@@ -237,7 +237,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@@ -17,7 +17,7 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@@ -214,7 +214,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@@ -221,6 +221,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
|
||||
@@ -270,7 +271,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -58,7 +58,6 @@ from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
@@ -66,16 +65,11 @@ from models.workflow import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage):
|
||||
class WorkflowAppGenerateTaskPipeline:
|
||||
"""
|
||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
@@ -84,7 +78,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
self._base_task_pipeline = BasedGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream,
|
||||
@@ -101,19 +95,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}")
|
||||
|
||||
self._workflow_cycle_manager = WorkflowCycleManage(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_system_variables={
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
},
|
||||
)
|
||||
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
|
||||
self._workflow_system_variables = {
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
SystemVariableKey.USER_ID: user_session_id,
|
||||
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: workflow.id,
|
||||
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
self._workflow_run_id = ""
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
@@ -122,7 +118,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
:return:
|
||||
"""
|
||||
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
|
||||
if self._stream:
|
||||
if self._base_task_pipeline._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
@@ -237,29 +233,29 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
"""
|
||||
graph_runtime_state = None
|
||||
|
||||
for queue_message in self._queue_manager.listen():
|
||||
for queue_message in self._base_task_pipeline._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
yield self._base_task_pipeline._ping_stream_response()
|
||||
elif isinstance(event, QueueErrorEvent):
|
||||
err = self._handle_error(event=event)
|
||||
yield self._error_to_stream_response(err)
|
||||
err = self._base_task_pipeline._handle_error(event=event)
|
||||
yield self._base_task_pipeline._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# init workflow run
|
||||
workflow_run = self._handle_workflow_run_start(
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start(
|
||||
session=session,
|
||||
workflow_id=self._workflow_id,
|
||||
user_id=self._user_id,
|
||||
created_by_role=self._created_by_role,
|
||||
)
|
||||
self._workflow_run_id = workflow_run.id
|
||||
start_resp = self._workflow_start_to_stream_response(
|
||||
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
@@ -271,12 +267,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
response = self._workflow_node_retry_to_stream_response(
|
||||
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -290,12 +288,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
|
||||
session=session, workflow_run=workflow_run, event=event
|
||||
)
|
||||
node_start_response = self._workflow_node_start_to_stream_response(
|
||||
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -306,9 +306,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if node_start_response:
|
||||
yield node_start_response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
|
||||
node_success_response = self._workflow_node_finish_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
|
||||
session=session, event=event
|
||||
)
|
||||
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -319,12 +321,12 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if node_success_response:
|
||||
yield node_success_response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
with Session(db.engine) as session:
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||
session=session,
|
||||
event=event,
|
||||
)
|
||||
node_failed_response = self._workflow_node_finish_to_stream_response(
|
||||
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
|
||||
session=session,
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -339,13 +341,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_start_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_start_resp
|
||||
@@ -354,13 +360,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
parallel_finish_resp = (
|
||||
self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
)
|
||||
|
||||
yield parallel_finish_resp
|
||||
@@ -369,9 +379,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_start_resp = self._workflow_iteration_start_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -384,9 +396,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_next_resp = self._workflow_iteration_next_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -399,9 +413,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
|
||||
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -416,8 +432,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -431,7 +447,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
@@ -445,8 +461,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_partial_success(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -461,7 +477,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
@@ -473,8 +489,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if not graph_runtime_state:
|
||||
raise ValueError("graph runtime state not initialized.")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
|
||||
session=session,
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
@@ -492,7 +508,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
|
||||
|
||||
workflow_finish_resp = self._workflow_finish_to_stream_response(
|
||||
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
|
||||
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
workflow_run_id: Optional[str] = None
|
||||
workflow_run_id: str
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
|
||||
@@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
PingStreamResponse,
|
||||
TaskState,
|
||||
)
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
@@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
|
||||
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: TaskState
|
||||
_application_generate_entity: AppGenerateEntity
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param user: user
|
||||
:param stream: stream
|
||||
"""
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._queue_manager = queue_manager
|
||||
self._start_at = time.perf_counter()
|
||||
|
||||
@@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class MessageCycleManage:
|
||||
_application_generate_entity: Union[
|
||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
|
||||
]
|
||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
],
|
||||
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
|
||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
|
||||
@@ -34,7 +34,6 @@ from core.app.entities.task_entities import (
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@@ -58,13 +57,20 @@ from models.workflow import (
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
|
||||
from .exc import WorkflowRunNotFoundError
|
||||
|
||||
|
||||
class WorkflowCycleManage:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_system_variables: dict[SystemVariableKey, Any],
|
||||
) -> None:
|
||||
self._workflow_run: WorkflowRun | None = None
|
||||
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
|
||||
def _handle_workflow_run_start(
|
||||
self,
|
||||
@@ -102,7 +108,8 @@ class WorkflowCycleManage:
|
||||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||
|
||||
# init workflow run
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
|
||||
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
|
||||
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.id = workflow_run_id
|
||||
@@ -239,7 +246,7 @@ class WorkflowCycleManage:
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
stmt = select(WorkflowNodeExecution.node_execution_id).where(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
@@ -247,16 +254,18 @@ class WorkflowCycleManage:
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
)
|
||||
|
||||
running_workflow_node_executions = session.scalars(stmt).all()
|
||||
ids = session.scalars(stmt).all()
|
||||
# Use self._get_workflow_node_execution here to make sure the cache is updated
|
||||
running_workflow_node_executions = [
|
||||
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
|
||||
]
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
now = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.elapsed_time = (
|
||||
workflow_node_execution.finished_at - workflow_node_execution.created_at
|
||||
).total_seconds()
|
||||
workflow_node_execution.finished_at = now
|
||||
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
@@ -274,7 +283,7 @@ class WorkflowCycleManage:
|
||||
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = event.node_execution_id
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
@@ -298,6 +307,8 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(
|
||||
@@ -325,6 +336,7 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
|
||||
workflow_node_execution = session.merge(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(
|
||||
@@ -364,6 +376,7 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
|
||||
workflow_node_execution = session.merge(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_retried(
|
||||
@@ -391,7 +404,7 @@ class WorkflowCycleManage:
|
||||
execution_metadata = json.dumps(merged_metadata)
|
||||
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = event.node_execution_id
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
@@ -415,6 +428,8 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
#################################################
|
||||
@@ -811,22 +826,20 @@ class WorkflowCycleManage:
|
||||
return None
|
||||
|
||||
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Refetch workflow run
|
||||
:param workflow_run_id: workflow run id
|
||||
:return:
|
||||
"""
|
||||
if self._workflow_run and self._workflow_run.id == workflow_run_id:
|
||||
cached_workflow_run = self._workflow_run
|
||||
cached_workflow_run = session.merge(cached_workflow_run)
|
||||
return cached_workflow_run
|
||||
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
||||
workflow_run = session.scalar(stmt)
|
||||
if not workflow_run:
|
||||
raise WorkflowRunNotFoundError(workflow_run_id)
|
||||
self._workflow_run = workflow_run
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.id == node_execution_id)
|
||||
workflow_node_execution = session.scalar(stmt)
|
||||
if not workflow_node_execution:
|
||||
raise WorkflowNodeExecutionNotFoundError(node_execution_id)
|
||||
|
||||
return workflow_node_execution
|
||||
if node_execution_id not in self._workflow_node_executions:
|
||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||
return cached_workflow_node_execution
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from os.path import abspath, dirname, join
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
|
||||
|
||||
_tokenizer: Any = None
|
||||
_lock = Lock()
|
||||
_executor = ProcessPoolExecutor(max_workers=1)
|
||||
|
||||
|
||||
class GPT2Tokenizer:
|
||||
@@ -20,7 +22,9 @@ class GPT2Tokenizer:
|
||||
|
||||
@staticmethod
|
||||
def get_num_tokens(text: str) -> int:
|
||||
return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
|
||||
future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
|
||||
result = future.result()
|
||||
return cast(int, result)
|
||||
|
||||
@staticmethod
|
||||
def get_encoder() -> Any:
|
||||
|
||||
@@ -24,8 +24,5 @@ class GiteeAIEmbeddingModel(OAICompatEmbeddingModel):
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None:
|
||||
if model is None:
|
||||
model = "bge-m3"
|
||||
|
||||
credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model}/v1/"
|
||||
def _add_custom_parameters(credentials: dict, model: str) -> None:
|
||||
credentials["endpoint_url"] = "https://ai.gitee.com/v1"
|
||||
|
||||
@@ -9,6 +9,8 @@ supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
- speech2text
|
||||
- tts
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
@@ -118,3 +120,19 @@ model_credential_schema:
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
- variable: voices
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: tts
|
||||
label:
|
||||
en_US: Available Voices (comma-separated)
|
||||
zh_Hans: 可用声音(用英文逗号分隔)
|
||||
type: text-input
|
||||
required: false
|
||||
default: "Chinese Female"
|
||||
placeholder:
|
||||
en_US: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
|
||||
zh_Hans: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
|
||||
help:
|
||||
en_US: "List voice names separated by commas. First voice will be used as default."
|
||||
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
@@ -24,9 +22,10 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
return super()._invoke(
|
||||
model,
|
||||
credentials,
|
||||
compatible_credentials,
|
||||
prompt_messages,
|
||||
model_parameters,
|
||||
tools,
|
||||
@@ -36,10 +35,15 @@ class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
super().validate_credentials(model, compatible_credentials)
|
||||
|
||||
def _get_compatible_credentials(self, credentials: dict) -> dict:
|
||||
credentials = credentials.copy()
|
||||
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
|
||||
credentials["endpoint_url"] = f"{base_url}/v1-openai"
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
|
||||
credentials["mode"] = "chat"
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import IO, Optional
|
||||
|
||||
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel
|
||||
|
||||
|
||||
class GPUStackSpeech2TextModel(OAICompatSpeech2TextModel):
|
||||
"""
|
||||
Model class for GPUStack Speech to text model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke speech2text model
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
return super()._invoke(model, compatible_credentials, file)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
"""
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
super().validate_credentials(model, compatible_credentials)
|
||||
|
||||
def _get_compatible_credentials(self, credentials: dict) -> dict:
|
||||
"""
|
||||
Get compatible credentials
|
||||
|
||||
:param credentials: model credentials
|
||||
:return: compatible credentials
|
||||
"""
|
||||
compatible_credentials = credentials.copy()
|
||||
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
|
||||
compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai"
|
||||
return compatible_credentials
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.model_runtime.entities.text_embedding_entities import (
|
||||
TextEmbeddingResult,
|
||||
@@ -24,12 +22,15 @@ class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
return super()._invoke(model, credentials, texts, user, input_type)
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
return super()._invoke(model, compatible_credentials, texts, user, input_type)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
super().validate_credentials(model, compatible_credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
|
||||
def _get_compatible_credentials(self, credentials: dict) -> dict:
|
||||
credentials = credentials.copy()
|
||||
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
|
||||
credentials["endpoint_url"] = f"{base_url}/v1-openai"
|
||||
return credentials
|
||||
|
||||
57
api/core/model_runtime/model_providers/gpustack/tts/tts.py
Normal file
57
api/core/model_runtime/model_providers/gpustack/tts/tts.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.model_runtime.model_providers.openai_api_compatible.tts.tts import OAICompatText2SpeechModel
|
||||
|
||||
|
||||
class GPUStackText2SpeechModel(OAICompatText2SpeechModel):
|
||||
"""
|
||||
Model class for GPUStack Text to Speech model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Invoke text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: model timbre
|
||||
:param user: unique user id
|
||||
:return: text translated to audio file
|
||||
"""
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
return super()._invoke(
|
||||
model=model,
|
||||
tenant_id=tenant_id,
|
||||
credentials=compatible_credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
user=user,
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param user: unique user id
|
||||
"""
|
||||
compatible_credentials = self._get_compatible_credentials(credentials)
|
||||
super().validate_credentials(model, compatible_credentials)
|
||||
|
||||
def _get_compatible_credentials(self, credentials: dict) -> dict:
|
||||
"""
|
||||
Get compatible credentials
|
||||
|
||||
:param credentials: model credentials
|
||||
:return: compatible credentials
|
||||
"""
|
||||
compatible_credentials = credentials.copy()
|
||||
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
|
||||
compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai"
|
||||
|
||||
return compatible_credentials
|
||||
@@ -18,6 +18,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -18,6 +18,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -18,6 +18,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -6,6 +6,7 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
@@ -19,6 +20,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -5,6 +5,7 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
@@ -18,6 +19,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -19,6 +19,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -19,6 +19,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -18,6 +18,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -18,6 +18,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -19,6 +19,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -19,6 +19,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
|
||||
@@ -5,6 +5,7 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
@@ -18,6 +19,18 @@ parameter_rules:
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32768
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: "0.05"
|
||||
output: "0.1"
|
||||
|
||||
@@ -5,6 +5,7 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
@@ -18,6 +19,18 @@ parameter_rules:
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 32768
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: "0.05"
|
||||
output: "0.1"
|
||||
|
||||
@@ -18,6 +18,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.20'
|
||||
output: '0.20'
|
||||
|
||||
@@ -18,6 +18,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.7'
|
||||
output: '0.8'
|
||||
|
||||
@@ -18,6 +18,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.59'
|
||||
output: '0.79'
|
||||
|
||||
@@ -5,6 +5,7 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
@@ -18,6 +19,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.08'
|
||||
|
||||
@@ -5,6 +5,7 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
@@ -18,6 +19,18 @@ parameter_rules:
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.08'
|
||||
|
||||
@@ -54,6 +54,7 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
|
||||
"Model": model,
|
||||
"Messages": messages_dict,
|
||||
"Stream": stream,
|
||||
"Stop": stop,
|
||||
**custom_parameters,
|
||||
}
|
||||
# add Tools and ToolChoice
|
||||
|
||||
@@ -252,7 +252,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().removeprefix("data: ")
|
||||
decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
|
||||
@@ -37,6 +37,9 @@ parameter_rules:
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
- json_schema
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '2.50'
|
||||
output: '10.00'
|
||||
|
||||
@@ -739,6 +739,12 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
|
||||
delta = chunk.choices[0]
|
||||
has_finish_reason = delta.finish_reason is not None
|
||||
# to fix issue #12215 yi model has special case for ligthing
|
||||
# FIXME drop the case when yi model is updated
|
||||
if model.startswith("yi-"):
|
||||
if isinstance(delta.finish_reason, str):
|
||||
# doc: https://platform.lingyiwanwu.com/docs/api-reference
|
||||
has_finish_reason = delta.finish_reason.startswith(("length", "stop", "content_filter"))
|
||||
|
||||
if (
|
||||
not has_finish_reason
|
||||
|
||||
@@ -332,6 +332,23 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
response_format = model_parameters.get("response_format")
|
||||
if response_format:
|
||||
if response_format == "json_schema":
|
||||
json_schema = model_parameters.get("json_schema")
|
||||
if not json_schema:
|
||||
raise ValueError("Must define JSON Schema when the response format is json_schema")
|
||||
try:
|
||||
schema = json.loads(json_schema)
|
||||
except:
|
||||
raise ValueError(f"not correct json_schema format: {json_schema}")
|
||||
model_parameters.pop("json_schema")
|
||||
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
|
||||
else:
|
||||
model_parameters["response_format"] = {"type": response_format}
|
||||
elif "json_schema" in model_parameters:
|
||||
del model_parameters["json_schema"]
|
||||
|
||||
data = {"model": model, "stream": stream, **model_parameters}
|
||||
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
@@ -462,7 +479,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().removeprefix("data: ")
|
||||
decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
|
||||
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
- Tencent/Hunyuan-A52B-Instruct
|
||||
- Qwen/QwQ-32B-Preview
|
||||
- Qwen/Qwen2.5-72B-Instruct
|
||||
- Qwen/Qwen2.5-32B-Instruct
|
||||
@@ -6,11 +5,9 @@
|
||||
- Qwen/Qwen2.5-7B-Instruct
|
||||
- Qwen/Qwen2.5-Coder-32B-Instruct
|
||||
- Qwen/Qwen2.5-Coder-7B-Instruct
|
||||
- Qwen/Qwen2.5-Math-72B-Instruct
|
||||
- Qwen/Qwen2-VL-72B-Instruct
|
||||
- Qwen/Qwen2-1.5B-Instruct
|
||||
- Pro/Qwen/Qwen2-VL-7B-Instruct
|
||||
- OpenGVLab/InternVL2-Llama3-76B
|
||||
- OpenGVLab/InternVL2-26B
|
||||
- Pro/OpenGVLab/InternVL2-8B
|
||||
- deepseek-ai/DeepSeek-V2.5
|
||||
|
||||
@@ -82,3 +82,4 @@ pricing:
|
||||
output: '21'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
deprecated: true
|
||||
|
||||
@@ -82,3 +82,4 @@ pricing:
|
||||
output: '21'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
deprecated: true
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
model: Qwen/QVQ-72B-Preview
|
||||
label:
|
||||
en_US: Qwen/QVQ-72B-Preview
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 16384
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '9.90'
|
||||
output: '9.90'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
@@ -15,9 +15,9 @@ parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 512
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
max: 8192
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
|
||||
@@ -78,7 +78,7 @@ parameter_rules:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '21'
|
||||
output: '21'
|
||||
input: '4.13'
|
||||
output: '4.13'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
|
||||
@@ -78,7 +78,7 @@ parameter_rules:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '21'
|
||||
output: '21'
|
||||
input: '0.35'
|
||||
output: '0.35'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
|
||||
@@ -82,3 +82,4 @@ pricing:
|
||||
output: '4.13'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
deprecated: true
|
||||
|
||||
@@ -250,7 +250,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().removeprefix("data: ")
|
||||
decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
|
||||
chunk_json = None
|
||||
try:
|
||||
chunk_json = json.loads(decoded_chunk)
|
||||
|
||||
@@ -122,6 +122,7 @@ class _CommonWenxin:
|
||||
"bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh",
|
||||
"tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k",
|
||||
"bce-reranker-base_v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base",
|
||||
"ernie-lite-pro-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-pro-128k",
|
||||
}
|
||||
|
||||
function_calling_supports = [
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
model: ernie-lite-pro-128k
|
||||
label:
|
||||
en_US: Ernie-Lite-Pro-128K
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0.1
|
||||
max: 1.0
|
||||
default: 0.8
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: min_output_tokens
|
||||
label:
|
||||
en_US: "Min Output Tokens"
|
||||
zh_Hans: "最小输出Token数"
|
||||
use_template: max_tokens
|
||||
min: 2
|
||||
max: 2048
|
||||
help:
|
||||
zh_Hans: 指定模型最小输出token数
|
||||
en_US: Specifies the lower limit on the length of generated results.
|
||||
- name: max_output_tokens
|
||||
label:
|
||||
en_US: "Max Output Tokens"
|
||||
zh_Hans: "最大输出Token数"
|
||||
use_template: max_tokens
|
||||
min: 2
|
||||
max: 2048
|
||||
default: 2048
|
||||
help:
|
||||
zh_Hans: 指定模型最大输出token数
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
@@ -47,6 +47,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.1'
|
||||
output: '0.1'
|
||||
|
||||
@@ -47,6 +47,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.001'
|
||||
output: '0.001'
|
||||
|
||||
@@ -47,6 +47,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.01'
|
||||
output: '0.01'
|
||||
|
||||
@@ -47,6 +47,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0'
|
||||
output: '0'
|
||||
|
||||
@@ -47,6 +47,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0'
|
||||
output: '0'
|
||||
|
||||
@@ -47,6 +47,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.1'
|
||||
output: '0.1'
|
||||
|
||||
@@ -49,6 +49,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.001'
|
||||
output: '0.001'
|
||||
|
||||
@@ -47,6 +47,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.05'
|
||||
|
||||
@@ -45,6 +45,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.05'
|
||||
|
||||
@@ -45,6 +45,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
|
||||
@@ -46,6 +46,18 @@ parameter_rules:
|
||||
help:
|
||||
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
|
||||
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '0.01'
|
||||
output: '0.01'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -188,6 +189,23 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
|
||||
else:
|
||||
model_parameters["tools"] = [web_search_params]
|
||||
|
||||
response_format = model_parameters.get("response_format")
|
||||
if response_format:
|
||||
if response_format == "json_schema":
|
||||
json_schema = model_parameters.get("json_schema")
|
||||
if not json_schema:
|
||||
raise ValueError("Must define JSON Schema when the response format is json_schema")
|
||||
try:
|
||||
schema = json.loads(json_schema)
|
||||
except:
|
||||
raise ValueError(f"not correct json_schema format: {json_schema}")
|
||||
model_parameters.pop("json_schema")
|
||||
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
|
||||
else:
|
||||
model_parameters["response_format"] = {"type": response_format}
|
||||
elif "json_schema" in model_parameters:
|
||||
del model_parameters["json_schema"]
|
||||
|
||||
if model.startswith("glm-4v"):
|
||||
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
|
||||
else:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
|
||||
class JiebaKeywordTableHandler:
|
||||
@@ -8,18 +8,20 @@ class JiebaKeywordTableHandler:
|
||||
|
||||
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
|
||||
|
||||
jieba.analyse.default_tfidf.stop_words = STOPWORDS
|
||||
jieba.analyse.default_tfidf.stop_words = STOPWORDS # type: ignore
|
||||
|
||||
def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
|
||||
"""Extract keywords with JIEBA tfidf."""
|
||||
import jieba # type: ignore
|
||||
import jieba.analyse # type: ignore
|
||||
|
||||
keywords = jieba.analyse.extract_tags(
|
||||
sentence=text,
|
||||
topK=max_keywords_per_chunk,
|
||||
)
|
||||
# jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
|
||||
keywords = cast(list[str], keywords)
|
||||
|
||||
return set(self._expand_tokens_with_subtokens(keywords))
|
||||
return set(self._expand_tokens_with_subtokens(set(keywords)))
|
||||
|
||||
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
|
||||
"""Get subtokens from a list of tokens., filtering for stopwords."""
|
||||
|
||||
@@ -113,6 +113,8 @@ class BaiduVector(BaseVector):
|
||||
return False
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
quoted_ids = [f"'{id}'" for id in ids]
|
||||
self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")
|
||||
|
||||
|
||||
@@ -83,6 +83,8 @@ class ChromaVector(BaseVector):
|
||||
self._client.delete_collection(self._collection_name)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
collection.delete(ids=ids)
|
||||
|
||||
|
||||
@@ -98,6 +98,8 @@ class ElasticSearchVector(BaseVector):
|
||||
return bool(self._client.exists(index=self._collection_name, id=id))
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
for id in ids:
|
||||
self._client.delete(index=self._collection_name, id=id)
|
||||
|
||||
|
||||
@@ -100,6 +100,8 @@ class MyScaleVector(BaseVector):
|
||||
return results.row_count > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||
)
|
||||
|
||||
@@ -134,6 +134,8 @@ class OceanBaseVector(BaseVector):
|
||||
return bool(cur.rowcount != 0)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
self._client.delete(table_name=self._collection_name, ids=ids)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
|
||||
|
||||
@@ -167,6 +167,8 @@ class OracleVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user