mirror of
https://github.com/langgenius/dify.git
synced 2026-03-01 12:55:13 +00:00
Compare commits
5 Commits
verify-ema
...
llm-quota-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f41f624c50 | ||
|
|
9c61b9b325 | ||
|
|
0d9eb1583d | ||
|
|
e028e07953 | ||
|
|
27601fab44 |
@@ -29,6 +29,8 @@ ignore_imports =
|
||||
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
@@ -107,14 +109,12 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||
core.workflow.nodes.llm.llm_utils -> configs
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||
core.workflow.nodes.llm.protocols -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> models.model
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
||||
core.workflow.nodes.llm.node -> core.tools.signature
|
||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
@@ -135,8 +135,8 @@ ignore_imports =
|
||||
core.workflow.nodes.start.start_node -> core.app.app_config.entities
|
||||
core.workflow.workflow_entry -> core.app.apps.exc
|
||||
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
|
||||
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
@@ -180,7 +180,7 @@ ignore_imports =
|
||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||
core.workflow.nodes.agent.agent_node -> models
|
||||
core.workflow.nodes.base.node -> models.enums
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider_ids
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
|
||||
core.workflow.nodes.llm.node -> models.model
|
||||
core.workflow.workflow_entry -> models.enums
|
||||
core.workflow.nodes.agent.agent_node -> services
|
||||
|
||||
@@ -133,6 +133,7 @@ class EducationAutocompleteQuery(BaseModel):
|
||||
class ChangeEmailSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
phase: str | None = None
|
||||
token: str | None = None
|
||||
|
||||
|
||||
@@ -546,17 +547,13 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
account = None
|
||||
user_email = None
|
||||
email_for_sending = args.email.lower()
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
|
||||
if args.token is not None:
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if reset_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
reset_token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if reset_token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
@@ -576,7 +573,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
email=email_for_sending,
|
||||
old_email=user_email,
|
||||
language=language,
|
||||
phase=send_phase,
|
||||
phase=args.phase,
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@@ -611,26 +608,12 @@ class ChangeEmailCheckApi(Resource):
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
phase_transitions: dict[str, str] = {
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if not isinstance(token_phase, str):
|
||||
raise InvalidTokenError()
|
||||
refreshed_phase = phase_transitions.get(token_phase)
|
||||
if refreshed_phase is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_change_email_token(
|
||||
user_email,
|
||||
code=args.code,
|
||||
old_email=token_data.get("old_email"),
|
||||
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
|
||||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
@@ -660,22 +643,13 @@ class ChangeEmailResetApi(Resource):
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = reset_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != normalized_new_email:
|
||||
raise InvalidTokenError()
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
"""LLM-related application services."""
|
||||
|
||||
from .quota import deduct_llm_quota, ensure_llm_quota_available
|
||||
|
||||
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]
|
||||
|
||||
93
api/core/app/llm/quota.py
Normal file
93
api/core/app/llm/quota.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
provider_model = provider_configuration.get_provider_model(
|
||||
model_type=model_instance.model_type_instance.model_type,
|
||||
model=model_instance.model_name,
|
||||
)
|
||||
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
|
||||
|
||||
|
||||
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
|
||||
|
||||
from .llm_quota import LLMQuotaLayer
|
||||
from .observability import ObservabilityLayer
|
||||
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
|
||||
__all__ = [
|
||||
"LLMQuotaLayer",
|
||||
"ObservabilityLayer",
|
||||
"PersistenceWorkflowInfo",
|
||||
"WorkflowPersistenceLayer",
|
||||
|
||||
128
api/core/app/workflow/layers/llm_quota.py
Normal file
128
api/core/app/workflow/layers/llm_quota.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
LLM quota deduction layer for GraphEngine.
|
||||
|
||||
This layer centralizes model-quota deduction outside node implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_manager import ModelInstance
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class LLMQuotaLayer(GraphEngineLayer):
|
||||
"""Graph layer that applies LLM quota deduction after node execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._abort_sent = False
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
self._abort_sent = False
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
_ = event
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
_ = error
|
||||
|
||||
@override
|
||||
def on_node_run_start(self, node: Node) -> None:
|
||||
if self._abort_sent:
|
||||
return
|
||||
|
||||
model_instance = self._extract_model_instance(node)
|
||||
if model_instance is None:
|
||||
return
|
||||
|
||||
try:
|
||||
ensure_llm_quota_available(model_instance=model_instance)
|
||||
except QuotaExceededError as exc:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=str(exc))
|
||||
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
|
||||
|
||||
@override
|
||||
def on_node_run_end(
|
||||
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
|
||||
) -> None:
|
||||
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
|
||||
return
|
||||
|
||||
model_instance = self._extract_model_instance(node)
|
||||
if model_instance is None:
|
||||
return
|
||||
|
||||
try:
|
||||
deduct_llm_quota(
|
||||
tenant_id=node.tenant_id,
|
||||
model_instance=model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
except QuotaExceededError as exc:
|
||||
self._set_stop_event(node)
|
||||
self._send_abort_command(reason=str(exc))
|
||||
logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
|
||||
except Exception:
|
||||
logger.exception("LLM quota deduction failed, node_id=%s", node.id)
|
||||
|
||||
@staticmethod
|
||||
def _set_stop_event(node: Node) -> None:
|
||||
stop_event = getattr(node.graph_runtime_state, "stop_event", None)
|
||||
if stop_event is not None:
|
||||
stop_event.set()
|
||||
|
||||
def _send_abort_command(self, *, reason: str) -> None:
|
||||
if not self.command_channel or self._abort_sent:
|
||||
return
|
||||
|
||||
try:
|
||||
self.command_channel.send_command(
|
||||
AbortCommand(
|
||||
command_type=CommandType.ABORT,
|
||||
reason=reason,
|
||||
)
|
||||
)
|
||||
self._abort_sent = True
|
||||
except Exception:
|
||||
logger.exception("Failed to send quota abort command")
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_instance(node: Node) -> ModelInstance | None:
|
||||
try:
|
||||
match node.node_type:
|
||||
case NodeType.LLM:
|
||||
return cast("LLMNode", node).model_instance
|
||||
case NodeType.PARAMETER_EXTRACTOR:
|
||||
return cast("ParameterExtractorNode", node).model_instance
|
||||
case NodeType.QUESTION_CLASSIFIER:
|
||||
return cast("QuestionClassifierNode", node).model_instance
|
||||
case _:
|
||||
return None
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
|
||||
node.id,
|
||||
)
|
||||
return None
|
||||
@@ -2,6 +2,7 @@ import tempfile
|
||||
from binascii import hexlify, unhexlify
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
@@ -29,7 +30,6 @@ from core.plugin.entities.request import (
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
@@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def handle() -> Generator[LLMResultChunk, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
@@ -126,16 +124,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
for chunk in response:
|
||||
if chunk.delta.usage:
|
||||
llm_utils.deduct_llm_quota(
|
||||
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
|
||||
)
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
|
||||
chunk.prompt_messages = []
|
||||
yield chunk
|
||||
|
||||
return handle()
|
||||
else:
|
||||
if response.usage:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
|
||||
|
||||
def handle_non_streaming(
|
||||
response: LLMResultWithStructuredOutput,
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
@@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs import helper
|
||||
@@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
# Deduct quota for summary generation (same as workflow nodes)
|
||||
try:
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
except Exception as e:
|
||||
# Log but don't fail summary generation if quota deduction fails
|
||||
logger.warning("Failed to deduct quota for summary generation: %s", str(e))
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections.abc import Generator, Sequence
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.llm import deduct_llm_quota
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
@@ -9,7 +10,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.rag.retrieval.output_parser.react_output import ReactAction
|
||||
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
|
||||
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
|
||||
|
||||
@@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
|
||||
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
|
||||
@@ -588,6 +588,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
def _create_graph_engine(self, index: int, item: object):
|
||||
# Import dependencies
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
@@ -642,5 +643,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
return graph_engine
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
@@ -17,10 +14,7 @@ from core.workflow.file.models import File
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
from .exc import InvalidVariableTypeError
|
||||
|
||||
@@ -68,68 +62,3 @@ def fetch_memory(
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
return memory
|
||||
|
||||
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
)
|
||||
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=used_quota,
|
||||
pool_type="paid",
|
||||
)
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
@@ -278,8 +278,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
else None
|
||||
)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
elif isinstance(event, LLMStructuredOutput):
|
||||
structured_output = event
|
||||
@@ -1234,6 +1232,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def retry(self) -> bool:
|
||||
return self.node_data.retry_config.retry_enabled
|
||||
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
return self._model_instance
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
||||
|
||||
@@ -413,6 +413,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
|
||||
# Import dependencies
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
@@ -454,5 +455,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
return graph_engine
|
||||
|
||||
@@ -308,9 +308,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
usage = invoke_result.usage
|
||||
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage, tool_call
|
||||
|
||||
def _generate_function_call_prompt(
|
||||
@@ -828,6 +825,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
return rest_tokens
|
||||
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
return self._model_instance
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
@@ -240,6 +240,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
@property
|
||||
def model_instance(self) -> ModelInstance:
|
||||
return self._model_instance
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, cast
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
@@ -106,6 +107,7 @@ class WorkflowEntry:
|
||||
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
)
|
||||
self.graph_engine.layer(limits_layer)
|
||||
self.graph_engine.layer(LLMQuotaLayer())
|
||||
|
||||
# Add observability layer when OTel is enabled
|
||||
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
|
||||
|
||||
@@ -4,7 +4,6 @@ import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from hashlib import sha256
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -81,25 +80,12 @@ class TokenPair(BaseModel):
|
||||
csrf_token: str
|
||||
|
||||
|
||||
class ChangeEmailPhase(StrEnum):
|
||||
OLD = "old_email"
|
||||
OLD_VERIFIED = "old_email_verified"
|
||||
NEW = "new_email"
|
||||
NEW_VERIFIED = "new_email_verified"
|
||||
|
||||
|
||||
REFRESH_TOKEN_PREFIX = "refresh_token:"
|
||||
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
|
||||
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
|
||||
class AccountService:
|
||||
CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
|
||||
CHANGE_EMAIL_PHASE_OLD = ChangeEmailPhase.OLD
|
||||
CHANGE_EMAIL_PHASE_OLD_VERIFIED = ChangeEmailPhase.OLD_VERIFIED
|
||||
CHANGE_EMAIL_PHASE_NEW = ChangeEmailPhase.NEW
|
||||
CHANGE_EMAIL_PHASE_NEW_VERIFIED = ChangeEmailPhase.NEW_VERIFIED
|
||||
|
||||
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_code_login_rate_limiter = RateLimiter(
|
||||
@@ -555,20 +541,13 @@ class AccountService:
|
||||
raise ValueError("Email must be provided.")
|
||||
if not phase:
|
||||
raise ValueError("phase must be provided.")
|
||||
if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
|
||||
raise ValueError("phase must be one of old_email or new_email.")
|
||||
|
||||
if cls.change_email_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import EmailChangeRateLimitExceededError
|
||||
|
||||
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_change_email_token(
|
||||
account_email,
|
||||
account,
|
||||
old_email=old_email,
|
||||
additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
|
||||
)
|
||||
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
|
||||
|
||||
send_change_mail_task.delay(
|
||||
language=language,
|
||||
|
||||
@@ -4,7 +4,6 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
from controllers.console.workspace.account import (
|
||||
AccountDeleteUpdateFeedbackApi,
|
||||
ChangeEmailCheckApi,
|
||||
@@ -53,7 +52,7 @@ class TestChangeEmailSend:
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_infer_new_email_phase_from_token(
|
||||
def test_should_normalize_new_email_phase(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
@@ -69,16 +68,13 @@ class TestChangeEmailSend:
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {
|
||||
"email": "current@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
}
|
||||
mock_get_change_data.return_value = {"email": "current@example.com"}
|
||||
mock_send_email.return_value = "token-abc"
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
|
||||
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
@@ -95,107 +91,6 @@ class TestChangeEmailSend:
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.db")
|
||||
@patch("controllers.console.workspace.account.Session")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_ignore_client_phase_and_use_old_phase_when_token_missing(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_account_by_email,
|
||||
mock_session_cls,
|
||||
mock_account_db,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("current@example.com", "current"), None)
|
||||
existing_account = _build_account("old@example.com", "acc-old")
|
||||
mock_get_account_by_email.return_value = existing_account
|
||||
mock_send_email.return_value = "token-legacy"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cm = MagicMock()
|
||||
mock_session_cm.__enter__.return_value = mock_session
|
||||
mock_session_cm.__exit__.return_value = None
|
||||
mock_session_cls.return_value = mock_session_cm
|
||||
mock_account_db.engine = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "old@example.com", "language": "en-US", "phase": "new_email"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-legacy"}
|
||||
mock_get_account_by_email.assert_called_once_with("old@example.com", session=mock_session)
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=existing_account,
|
||||
email="old@example.com",
|
||||
old_email="old@example.com",
|
||||
language="en-US",
|
||||
phase=AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_unverified_old_email_token_for_new_email_phase(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {
|
||||
"email": "current@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailSendEmailApi().post()
|
||||
|
||||
mock_send_email.assert_not_called()
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailValidity:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@@ -227,12 +122,7 @@ class TestChangeEmailValidity:
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {
|
||||
"email": "user@example.com",
|
||||
"code": "1234",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
@@ -248,76 +138,11 @@ class TestChangeEmailValidity:
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com",
|
||||
code="1234",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED
|
||||
},
|
||||
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_refresh_new_email_phase_to_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("old@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {
|
||||
"email": "new@example.com",
|
||||
"code": "5678",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
}
|
||||
mock_generate_token.return_value = (None, "new-phase-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "code": "5678", "token": "token-456"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-phase-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("new@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-456")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"new@example.com",
|
||||
code="5678",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED
|
||||
},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailReset:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@@ -350,11 +175,7 @@ class TestChangeEmailReset:
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "new@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
mock_get_data.return_value = {"old_email": "OLD@example.com"}
|
||||
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
|
||||
mock_update_account.return_value = mock_account_after_update
|
||||
|
||||
@@ -373,106 +194,6 @@ class TestChangeEmailReset:
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_old_phase_token_for_reset(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_mismatched_new_email_for_verified_token(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "another@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-789"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.commands import CommandType
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
def _build_succeeded_event() -> NodeRunSucceededEvent:
|
||||
return NodeRunSucceededEvent(
|
||||
id="execution-id",
|
||||
node_id="llm-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.now(),
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"question": "hello"},
|
||||
llm_usage=LLMUsage.empty_usage(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_successful_llm_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = NodeType.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=node.model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
||||
def test_deduct_quota_called_for_question_classifier_node() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "question-classifier-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = NodeType.QUESTION_CLASSIFIER
|
||||
node.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_called_once_with(
|
||||
tenant_id="tenant-id",
|
||||
model_instance=node.model_instance,
|
||||
usage=result_event.node_run_result.llm_usage,
|
||||
)
|
||||
|
||||
|
||||
def test_non_llm_node_is_ignored() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "start-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = NodeType.START
|
||||
node.tenant_id = "tenant-id"
|
||||
node._model_instance = object()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
mock_deduct.assert_not_called()
|
||||
|
||||
|
||||
def test_quota_error_is_handled_in_layer() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = NodeType.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
autospec=True,
|
||||
side_effect=ValueError("quota exceeded"),
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
|
||||
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.execution_id = "execution-id"
|
||||
node.node_type = NodeType.LLM
|
||||
node.tenant_id = "tenant-id"
|
||||
node.model_instance = object()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
result_event = _build_succeeded_event()
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("No credits remaining"),
|
||||
):
|
||||
layer.on_node_run_end(node=node, error=None, result_event=result_event)
|
||||
|
||||
assert stop_event.is_set()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
abort_command = layer.command_channel.send_command.call_args.args[0]
|
||||
assert abort_command.command_type == CommandType.ABORT
|
||||
assert abort_command.reason == "No credits remaining"
|
||||
|
||||
|
||||
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = NodeType.LLM
|
||||
node.model_instance = object()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch(
|
||||
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
|
||||
autospec=True,
|
||||
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
|
||||
):
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert stop_event.is_set()
|
||||
layer.command_channel.send_command.assert_called_once()
|
||||
abort_command = layer.command_channel.send_command.call_args.args[0]
|
||||
assert abort_command.command_type == CommandType.ABORT
|
||||
assert abort_command.reason == "Model provider openai quota exceeded."
|
||||
|
||||
|
||||
def test_quota_precheck_passes_without_abort() -> None:
|
||||
layer = LLMQuotaLayer()
|
||||
stop_event = threading.Event()
|
||||
layer.command_channel = MagicMock()
|
||||
|
||||
node = MagicMock()
|
||||
node.id = "llm-node-id"
|
||||
node.node_type = NodeType.LLM
|
||||
node.model_instance = object()
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.stop_event = stop_event
|
||||
|
||||
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
|
||||
layer.on_node_run_start(node)
|
||||
|
||||
assert not stop_event.is_set()
|
||||
mock_check.assert_called_once_with(model_instance=node.model_instance)
|
||||
layer.command_channel.send_command.assert_not_called()
|
||||
@@ -58,10 +58,11 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
}, 1000)
|
||||
}
|
||||
|
||||
const sendEmail = async (email: string, token?: string) => {
|
||||
const sendEmail = async (email: string, isOrigin: boolean, token?: string) => {
|
||||
try {
|
||||
const res = await sendVerifyCode({
|
||||
email,
|
||||
phase: isOrigin ? 'old_email' : 'new_email',
|
||||
token,
|
||||
})
|
||||
startCount()
|
||||
@@ -105,6 +106,7 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
const sendCodeToOriginEmail = async () => {
|
||||
await sendEmail(
|
||||
email,
|
||||
true,
|
||||
)
|
||||
setStep(STEP.verifyOrigin)
|
||||
}
|
||||
@@ -160,6 +162,7 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
}
|
||||
await sendEmail(
|
||||
mail,
|
||||
false,
|
||||
stepToken,
|
||||
)
|
||||
setStep(STEP.verifyNew)
|
||||
|
||||
@@ -372,7 +372,7 @@ export const submitDeleteAccountFeedback = (body: { feedback: string, email: str
|
||||
export const getDocDownloadUrl = (doc_name: string): Promise<{ url: string }> =>
|
||||
get<{ url: string }>('/compliance/download', { params: { doc_name } }, { silent: true })
|
||||
|
||||
export const sendVerifyCode = (body: { email: string, token?: string }): Promise<CommonResponse & { data: string }> =>
|
||||
export const sendVerifyCode = (body: { email: string, phase: string, token?: string }): Promise<CommonResponse & { data: string }> =>
|
||||
post<CommonResponse & { data: string }>('/account/change-email', { body })
|
||||
|
||||
export const verifyEmail = (body: { email: string, code: string, token: string }): Promise<CommonResponse & { is_valid: boolean, email: string, token: string }> =>
|
||||
|
||||
Reference in New Issue
Block a user