mirror of
https://github.com/langgenius/dify.git
synced 2026-01-08 07:14:14 +00:00
chore: adopt StrEnum and auto() for some string-typed enums (#25129)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
import base64
|
||||
import contextlib
|
||||
import enum
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
|
||||
@@ -22,37 +21,37 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
|
||||
class ToolLabelEnum(Enum):
|
||||
SEARCH = "search"
|
||||
IMAGE = "image"
|
||||
VIDEOS = "videos"
|
||||
WEATHER = "weather"
|
||||
FINANCE = "finance"
|
||||
DESIGN = "design"
|
||||
TRAVEL = "travel"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
MEDICAL = "medical"
|
||||
PRODUCTIVITY = "productivity"
|
||||
EDUCATION = "education"
|
||||
BUSINESS = "business"
|
||||
ENTERTAINMENT = "entertainment"
|
||||
UTILITIES = "utilities"
|
||||
OTHER = "other"
|
||||
class ToolLabelEnum(StrEnum):
|
||||
SEARCH = auto()
|
||||
IMAGE = auto()
|
||||
VIDEOS = auto()
|
||||
WEATHER = auto()
|
||||
FINANCE = auto()
|
||||
DESIGN = auto()
|
||||
TRAVEL = auto()
|
||||
SOCIAL = auto()
|
||||
NEWS = auto()
|
||||
MEDICAL = auto()
|
||||
PRODUCTIVITY = auto()
|
||||
EDUCATION = auto()
|
||||
BUSINESS = auto()
|
||||
ENTERTAINMENT = auto()
|
||||
UTILITIES = auto()
|
||||
OTHER = auto()
|
||||
|
||||
|
||||
class ToolProviderType(enum.StrEnum):
|
||||
class ToolProviderType(StrEnum):
|
||||
"""
|
||||
Enum class for tool provider
|
||||
"""
|
||||
|
||||
PLUGIN = "plugin"
|
||||
PLUGIN = auto()
|
||||
BUILT_IN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
API = "api"
|
||||
APP = "app"
|
||||
WORKFLOW = auto()
|
||||
API = auto()
|
||||
APP = auto()
|
||||
DATASET_RETRIEVAL = "dataset-retrieval"
|
||||
MCP = "mcp"
|
||||
MCP = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ToolProviderType":
|
||||
@@ -68,15 +67,15 @@ class ToolProviderType(enum.StrEnum):
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderSchemaType(Enum):
|
||||
class ApiProviderSchemaType(StrEnum):
|
||||
"""
|
||||
Enum class for api provider schema type.
|
||||
"""
|
||||
|
||||
OPENAPI = "openapi"
|
||||
SWAGGER = "swagger"
|
||||
OPENAI_PLUGIN = "openai_plugin"
|
||||
OPENAI_ACTIONS = "openai_actions"
|
||||
OPENAPI = auto()
|
||||
SWAGGER = auto()
|
||||
OPENAI_PLUGIN = auto()
|
||||
OPENAI_ACTIONS = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderSchemaType":
|
||||
@@ -92,14 +91,14 @@ class ApiProviderSchemaType(Enum):
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class ApiProviderAuthType(Enum):
|
||||
class ApiProviderAuthType(StrEnum):
|
||||
"""
|
||||
Enum class for api provider auth type.
|
||||
"""
|
||||
|
||||
NONE = "none"
|
||||
API_KEY_HEADER = "api_key_header"
|
||||
API_KEY_QUERY = "api_key_query"
|
||||
NONE = auto()
|
||||
API_KEY_HEADER = auto()
|
||||
API_KEY_QUERY = auto()
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ApiProviderAuthType":
|
||||
@@ -176,10 +175,10 @@ class ToolInvokeMessage(BaseModel):
|
||||
return value
|
||||
|
||||
class LogMessage(BaseModel):
|
||||
class LogStatus(Enum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
class LogStatus(StrEnum):
|
||||
START = auto()
|
||||
ERROR = auto()
|
||||
SUCCESS = auto()
|
||||
|
||||
id: str
|
||||
label: str = Field(..., description="The label of the log")
|
||||
@@ -193,19 +192,19 @@ class ToolInvokeMessage(BaseModel):
|
||||
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
LINK = "link"
|
||||
BLOB = "blob"
|
||||
JSON = "json"
|
||||
IMAGE_LINK = "image_link"
|
||||
BINARY_LINK = "binary_link"
|
||||
VARIABLE = "variable"
|
||||
FILE = "file"
|
||||
LOG = "log"
|
||||
BLOB_CHUNK = "blob_chunk"
|
||||
RETRIEVER_RESOURCES = "retriever_resources"
|
||||
class MessageType(StrEnum):
|
||||
TEXT = auto()
|
||||
IMAGE = auto()
|
||||
LINK = auto()
|
||||
BLOB = auto()
|
||||
JSON = auto()
|
||||
IMAGE_LINK = auto()
|
||||
BINARY_LINK = auto()
|
||||
VARIABLE = auto()
|
||||
FILE = auto()
|
||||
LOG = auto()
|
||||
BLOB_CHUNK = auto()
|
||||
RETRIEVER_RESOURCES = auto()
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
@@ -250,29 +249,29 @@ class ToolParameter(PluginParameter):
|
||||
Overrides type
|
||||
"""
|
||||
|
||||
class ToolParameterType(enum.StrEnum):
|
||||
class ToolParameterType(StrEnum):
|
||||
"""
|
||||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = PluginParameterType.STRING.value
|
||||
NUMBER = PluginParameterType.NUMBER.value
|
||||
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||
SELECT = PluginParameterType.SELECT.value
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
ANY = PluginParameterType.ANY.value
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
|
||||
STRING = PluginParameterType.STRING
|
||||
NUMBER = PluginParameterType.NUMBER
|
||||
BOOLEAN = PluginParameterType.BOOLEAN
|
||||
SELECT = PluginParameterType.SELECT
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT
|
||||
FILE = PluginParameterType.FILE
|
||||
FILES = PluginParameterType.FILES
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR
|
||||
ANY = PluginParameterType.ANY
|
||||
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT
|
||||
|
||||
# MCP object and array type parameters
|
||||
ARRAY = MCPServerParameterType.ARRAY.value
|
||||
OBJECT = MCPServerParameterType.OBJECT.value
|
||||
ARRAY = MCPServerParameterType.ARRAY
|
||||
OBJECT = MCPServerParameterType.OBJECT
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
@@ -280,10 +279,10 @@ class ToolParameter(PluginParameter):
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
FORM = "form" # should be set before invoking tool
|
||||
LLM = "llm" # will be set by LLM
|
||||
class ToolParameterForm(StrEnum):
|
||||
SCHEMA = auto() # should be set while adding tool
|
||||
FORM = auto() # should be set before invoking tool
|
||||
LLM = auto() # will be set by LLM
|
||||
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
@@ -446,14 +445,14 @@ class ToolLabel(BaseModel):
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class ToolInvokeFrom(Enum):
|
||||
class ToolInvokeFrom(StrEnum):
|
||||
"""
|
||||
Enum class for tool invoke
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
||||
PLUGIN = "plugin"
|
||||
WORKFLOW = auto()
|
||||
AGENT = auto()
|
||||
PLUGIN = auto()
|
||||
|
||||
|
||||
class ToolSelector(BaseModel):
|
||||
@@ -478,9 +477,9 @@ class ToolSelector(BaseModel):
|
||||
return self.model_dump()
|
||||
|
||||
|
||||
class CredentialType(enum.StrEnum):
|
||||
class CredentialType(StrEnum):
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
OAUTH2 = auto()
|
||||
|
||||
def get_name(self):
|
||||
if self == CredentialType.API_KEY:
|
||||
|
||||
Reference in New Issue
Block a user