Feat/email register refactor (#25369)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
zyssyz123
2025-09-12 10:24:54 +08:00
committed by GitHub
parent bb1514be2d
commit c2fcd2895b
36 changed files with 2390 additions and 91 deletions

View File

@@ -37,7 +37,6 @@ from services.billing_service import BillingService
from services.errors.account import (
AccountAlreadyInTenantError,
AccountLoginError,
AccountNotFoundError,
AccountNotLinkTenantError,
AccountPasswordError,
AccountRegisterError,
@@ -65,7 +64,11 @@ from tasks.mail_owner_transfer_task import (
send_old_owner_transfer_notify_email_task,
send_owner_transfer_confirm_task,
)
from tasks.mail_reset_password_task import send_reset_password_mail_task
from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist
from tasks.mail_reset_password_task import (
send_reset_password_mail_task,
send_reset_password_mail_task_when_account_not_exist,
)
logger = logging.getLogger(__name__)
@@ -82,8 +85,9 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
prefix="email_code_login_rate_limit", max_attempts=3, time_window=300 * 1
)
email_code_account_deletion_rate_limiter = RateLimiter(
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
@@ -95,6 +99,7 @@ class AccountService:
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
CHANGE_EMAIL_MAX_ERROR_LIMITS = 5
OWNER_TRANSFER_MAX_ERROR_LIMITS = 5
EMAIL_REGISTER_MAX_ERROR_LIMITS = 5
@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
@@ -171,7 +176,7 @@ class AccountService:
account = db.session.query(Account).filter_by(email=email).first()
if not account:
raise AccountNotFoundError()
raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.BANNED.value:
raise AccountLoginError("Account is banned.")
@@ -296,7 +301,9 @@ class AccountService:
if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError
raise EmailCodeAccountDeletionRateLimitExceededError()
raise EmailCodeAccountDeletionRateLimitExceededError(
int(cls.email_code_account_deletion_rate_limiter.time_window / 60)
)
send_account_deletion_verification_code.delay(to=email, code=code)
@@ -435,6 +442,7 @@ class AccountService:
account: Optional[Account] = None,
email: Optional[str] = None,
language: str = "en-US",
is_allow_register: bool = False,
):
account_email = account.email if account else email
if account_email is None:
@@ -443,18 +451,59 @@ class AccountService:
if cls.reset_password_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import PasswordResetRateLimitExceededError
raise PasswordResetRateLimitExceededError()
raise PasswordResetRateLimitExceededError(int(cls.reset_password_rate_limiter.time_window / 60))
code, token = cls.generate_reset_password_token(account_email, account)
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
if account:
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
else:
send_reset_password_mail_task_when_account_not_exist.delay(
language=language,
to=account_email,
is_allow_register=is_allow_register,
)
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token
@classmethod
def send_email_register_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
language: str = "en-US",
):
account_email = account.email if account else email
if account_email is None:
raise ValueError("Email must be provided.")
if cls.email_register_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailRegisterRateLimitExceededError
raise EmailRegisterRateLimitExceededError(int(cls.email_register_rate_limiter.time_window / 60))
code, token = cls.generate_email_register_token(account_email)
if account:
send_email_register_mail_task_when_account_exist.delay(
language=language,
to=account_email,
account_name=account.name,
)
else:
send_email_register_mail_task.delay(
language=language,
to=account_email,
code=code,
)
cls.email_register_rate_limiter.increment_rate_limit(account_email)
return token
@classmethod
def send_change_email_email(
cls,
@@ -473,7 +522,7 @@ class AccountService:
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError()
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
@@ -517,7 +566,7 @@ class AccountService:
if cls.owner_transfer_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import OwnerTransferRateLimitExceededError
raise OwnerTransferRateLimitExceededError()
raise OwnerTransferRateLimitExceededError(int(cls.owner_transfer_rate_limiter.time_window / 60))
code, token = cls.generate_owner_transfer_token(account_email, account)
workspace_name = workspace_name or ""
@@ -587,6 +636,19 @@ class AccountService:
)
return code, token
@classmethod
def generate_email_register_token(
cls,
email: str,
code: Optional[str] = None,
additional_data: dict[str, Any] = {},
):
if not code:
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
additional_data["code"] = code
token = TokenManager.generate_token(email=email, token_type="email_register", additional_data=additional_data)
return code, token
@classmethod
def generate_change_email_token(
cls,
@@ -625,6 +687,10 @@ class AccountService:
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password")
@classmethod
def revoke_email_register_token(cls, token: str):
TokenManager.revoke_token(token, "email_register")
@classmethod
def revoke_change_email_token(cls, token: str):
TokenManager.revoke_token(token, "change_email")
@@ -637,6 +703,10 @@ class AccountService:
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "reset_password")
@classmethod
def get_email_register_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "email_register")
@classmethod
def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "change_email")
@@ -658,7 +728,7 @@ class AccountService:
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
raise EmailCodeLoginRateLimitExceededError()
raise EmailCodeLoginRateLimitExceededError(int(cls.email_code_login_rate_limiter.time_window / 60))
code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
token = TokenManager.generate_token(
@@ -744,6 +814,16 @@ class AccountService:
count = int(count) + 1
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=None)
def add_email_register_error_rate_limit(email: str) -> None:
key = f"email_register_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
count = 0
count = int(count) + 1
redis_client.setex(key, dify_config.EMAIL_REGISTER_LOCKOUT_DURATION, count)
@staticmethod
@redis_fallback(default_return=False)
def is_forgot_password_error_rate_limit(email: str) -> bool:
@@ -763,6 +843,24 @@ class AccountService:
key = f"forgot_password_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=False)
def is_email_register_error_rate_limit(email: str) -> bool:
key = f"email_register_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
return False
count = int(count)
if count > AccountService.EMAIL_REGISTER_MAX_ERROR_LIMITS:
return True
return False
@staticmethod
@redis_fallback(default_return=None)
def reset_email_register_error_rate_limit(email: str):
key = f"email_register_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
@redis_fallback(default_return=None)
def add_change_email_error_rate_limit(email: str):