Compare commits

..

11 Commits

Author SHA1 Message Date
Wu Tianwei
3a90cd2d03 Merge branch 'main' into fix/plugin-tag-fallback 2025-09-16 17:24:56 +08:00
Jiang
b283b10d3e Fix/lindorm vdb optimize (#25748)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-16 16:54:18 +08:00
-LAN-
ecb22226d6 refactor: remove Claude-specific references from documentation files (#25760) 2025-09-16 14:22:14 +08:00
Xiyuan Chen
8635aacb46 Enhance LLM model configuration validation to include active status c… (#25759)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-15 23:15:53 -07:00
Asuka Minato
bdd85b36a4 ruff check preview (#25653)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-16 12:58:12 +08:00
znn
a0c7713494 chat remove transparency from chat bubble in dark mode (#24921) 2025-09-16 12:57:53 +08:00
NeatGuyCoding
abf4955c26 Feature: add test containers document indexing task (#25684)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2025-09-16 09:47:28 +08:00
miwa
74340e3c04 Bugfix: When i change the loop variable, 'Loop Termination Condition' wi… (#25695)
Co-authored-by: fengminhua <fengminhua@52tt.com>
2025-09-16 09:46:44 +08:00
-LAN-
b98b389baf fix(tests): resolve order dependency in disable_segments_from_index_task tests (#25737) 2025-09-16 08:26:52 +08:00
WTW0313
130c01ff6a feat: Add APP_VERSION to headers for marketplace API requests 2025-09-04 18:01:33 +08:00
WTW0313
a9b0b7a3b4 refactor: Simplify tag label retrieval in hooks and update related components 2025-09-04 17:32:15 +08:00
60 changed files with 1068 additions and 773 deletions

View File

@@ -22,7 +22,7 @@ jobs:
# Fix lint errors
uv run ruff check --fix .
# Format code
uv run ruff format .
uv run ruff format ..
- name: ast-grep
run: |

View File

@@ -1 +0,0 @@
CLAUDE.md

AGENTS.md (new file, 87 lines)
View File

@@ -0,0 +1,87 @@
# AGENTS.md
## Project Overview
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
The codebase consists of:
- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
- **Docker deployment** (`/docker`): Containerized deployment configurations
## Development Commands
### Backend (API)
All Python commands must be prefixed with `uv run --project api`:
```bash
# Start development servers
./dev/start-api # Start API server
./dev/start-worker # Start Celery worker
# Run tests
uv run --project api pytest # Run all tests
uv run --project api pytest tests/unit_tests/ # Unit tests only
uv run --project api pytest tests/integration_tests/ # Integration tests
# Code quality
./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code
uv run --directory api basedpyright # Type checking
```
### Frontend (Web)
```bash
cd web
pnpm lint # Run ESLint
pnpm eslint-fix # Fix ESLint issues
pnpm test # Run Jest tests
```
## Testing Guidelines
### Backend Testing
- Use `pytest` for all backend tests
- Write tests first (TDD approach)
- Test structure: Arrange-Act-Assert
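A minimal sketch of this layout, with hypothetical function and test names:
```python
# Hypothetical Arrange-Act-Assert example; the function and test are illustrative only.
def add_tag(tags: list[str], new_tag: str) -> list[str]:
    """Return a new list containing new_tag, without duplicating existing tags."""
    return list(tags) if new_tag in tags else [*tags, new_tag]


def test_add_tag_skips_duplicates():
    # Arrange: prepare the input data
    tags = ["search", "rag"]

    # Act: call the unit under test
    result = add_tag(tags, "rag")

    # Assert: verify the observable outcome
    assert result == ["search", "rag"]
```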
## Code Style Requirements
### Python
- Use type hints for all functions and class attributes
- No `Any` types unless absolutely necessary
- Implement special methods (`__repr__`, `__str__`) appropriately
### TypeScript/JavaScript
- Strict TypeScript configuration
- ESLint with Prettier integration
- Avoid `any` type
## Important Notes
- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
- **Comments**: Only write meaningful comments that explain "why", not "what"
- **File Creation**: Always prefer editing existing files over creating new ones
- **Documentation**: Don't create documentation files unless explicitly requested
- **Code Quality**: Always run `./dev/reformat` before committing backend changes
## Common Development Tasks
### Adding a New API Endpoint
1. Create controller in `/api/controllers/`
1. Add service logic in `/api/services/`
1. Update routes in controller's `__init__.py`
1. Write tests in `/api/tests/`
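A condensed sketch of the controller/service split described above, using plain Flask and hypothetical module names (the real controllers and services live under `/api/controllers/` and `/api/services/` as listed):
```python
# Hypothetical sketch of the controller/service split; paths and names are
# illustrative and do not correspond to actual Dify modules.
from flask import Blueprint, jsonify


# services/greeting_service.py (hypothetical): business logic belongs in a service.
def get_greeting(name: str) -> dict[str, str]:
    return {"message": f"Hello, {name}"}


# controllers/greeting.py (hypothetical): a thin HTTP layer that delegates to the service.
bp = Blueprint("greeting", __name__)


@bp.get("/greetings/<name>")
def greeting(name: str):
    return jsonify(get_greeting(name))

# The blueprint would then be registered in the controller package's __init__.py,
# with a matching test added under /api/tests/.
```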
## Project-Specific Conventions
- All async tasks use Celery with Redis as broker
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.

View File

@@ -1,89 +0,0 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
The codebase consists of:
- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
- **Docker deployment** (`/docker`): Containerized deployment configurations
## Development Commands
### Backend (API)
All Python commands must be prefixed with `uv run --project api`:
```bash
# Start development servers
./dev/start-api # Start API server
./dev/start-worker # Start Celery worker
# Run tests
uv run --project api pytest # Run all tests
uv run --project api pytest tests/unit_tests/ # Unit tests only
uv run --project api pytest tests/integration_tests/ # Integration tests
# Code quality
./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code
uv run --directory api basedpyright # Type checking
```
### Frontend (Web)
```bash
cd web
pnpm lint # Run ESLint
pnpm eslint-fix # Fix ESLint issues
pnpm test # Run Jest tests
```
## Testing Guidelines
### Backend Testing
- Use `pytest` for all backend tests
- Write tests first (TDD approach)
- Test structure: Arrange-Act-Assert
## Code Style Requirements
### Python
- Use type hints for all functions and class attributes
- No `Any` types unless absolutely necessary
- Implement special methods (`__repr__`, `__str__`) appropriately
### TypeScript/JavaScript
- Strict TypeScript configuration
- ESLint with Prettier integration
- Avoid `any` type
## Important Notes
- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
- **Comments**: Only write meaningful comments that explain "why", not "what"
- **File Creation**: Always prefer editing existing files over creating new ones
- **Documentation**: Don't create documentation files unless explicitly requested
- **Code Quality**: Always run `./dev/reformat` before committing backend changes
## Common Development Tasks
### Adding a New API Endpoint
1. Create controller in `/api/controllers/`
1. Add service logic in `/api/services/`
1. Update routes in controller's `__init__.py`
1. Write tests in `/api/tests/`
## Project-Specific Conventions
- All async tasks use Celery with Redis as broker
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.

CLAUDE.md (symbolic link, 1 line)
View File

@@ -0,0 +1 @@
AGENTS.md

View File

@@ -328,7 +328,7 @@ MATRIXONE_DATABASE=dify
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin
LINDORM_PASSWORD=admin
USING_UGC_INDEX=False
LINDORM_USING_UGC=True
LINDORM_QUERY_TIMEOUT=1
# OceanBase Vector configuration

View File

@@ -5,7 +5,7 @@ line-length = 120
quote-style = "double"
[lint]
preview = false
preview = true
select = [
"B", # flake8-bugbear rules
"C4", # flake8-comprehensions
@@ -45,6 +45,7 @@ select = [
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum
]
ignore = [
@@ -64,6 +65,7 @@ ignore = [
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B901", # allow return in yield
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict

View File

@@ -1,6 +1,7 @@
import base64
import json
import logging
import operator
import secrets
from typing import Any
@@ -953,7 +954,7 @@ def clear_orphaned_file_records(force: bool):
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
query = "DELETE FROM message_files WHERE id IN :ids"
with db.engine.begin() as conn:
conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)})
click.echo(
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
)
@@ -1307,7 +1308,7 @@ def cleanup_orphaned_draft_variables(
if dry_run:
logger.info("DRY RUN: Would delete the following:")
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=operator.itemgetter(1), reverse=True)[
:10
]: # Show top 10
logger.info(" App %s: %s variables", app_id, count)

View File

@@ -19,15 +19,15 @@ class LindormConfig(BaseSettings):
description="Lindorm password",
default=None,
)
DEFAULT_INDEX_TYPE: str | None = Field(
LINDORM_INDEX_TYPE: str | None = Field(
description="Lindorm Vector Index Type, hnsw or flat is available in dify",
default="hnsw",
)
DEFAULT_DISTANCE_TYPE: str | None = Field(
LINDORM_DISTANCE_TYPE: str | None = Field(
description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
)
USING_UGC_INDEX: bool | None = Field(
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
default=False,
LINDORM_USING_UGC: bool | None = Field(
description="Using UGC index will store indexes with the same IndexType/Dimension in a single big index.",
default=True,
)
LINDORM_QUERY_TIMEOUT: float | None = Field(description="The lindorm search request timeout (s)", default=2.0)
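As a rough usage sketch, the renamed settings are read the same way the factory code later in this changeset reads them (assuming the `configs.dify_config` import used elsewhere in the API; the dimension below is only a placeholder):
```python
# Illustrative only: reads the renamed Lindorm settings.
from configs import dify_config

index_type = dify_config.LINDORM_INDEX_TYPE        # formerly DEFAULT_INDEX_TYPE
distance_type = dify_config.LINDORM_DISTANCE_TYPE  # formerly DEFAULT_DISTANCE_TYPE
using_ugc = dify_config.LINDORM_USING_UGC          # formerly USING_UGC_INDEX, default now True

dimension = 768  # placeholder embedding dimension
# With UGC enabled, datasets sharing dimension, index type, and distance type
# are routed into one shared index, named the way the factory code in this changeset names it.
index_name = f"ugc_index_{dimension}_{index_type}_{distance_type}".lower()
```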

View File

@@ -355,8 +355,8 @@ class AliyunDataTrace(BaseTraceInstance):
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=self.get_workflow_node_status(node_execution),

View File

@@ -144,13 +144,13 @@ class LangFuseDataTrace(BaseTraceInstance):
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
execution_metadata = node_execution.metadata or {}
metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update(
{
@@ -163,7 +163,7 @@ class LangFuseDataTrace(BaseTraceInstance):
"status": status,
}
)
process_data = node_execution.process_data if node_execution.process_data else {}
process_data = node_execution.process_data or {}
model_provider = process_data.get("model_provider", None)
model_name = process_data.get("model_name", None)
if model_provider is not None and model_name is not None:

View File

@@ -167,13 +167,13 @@ class LangSmithDataTrace(BaseTraceInstance):
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
execution_metadata = node_execution.metadata or {}
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
metadata = {str(key): value for key, value in execution_metadata.items()}
metadata.update(
@@ -188,7 +188,7 @@ class LangSmithDataTrace(BaseTraceInstance):
}
)
process_data = node_execution.process_data if node_execution.process_data else {}
process_data = node_execution.process_data or {}
if process_data and process_data.get("model_mode") == "chat":
run_type = LangSmithRunType.llm

View File

@@ -182,13 +182,13 @@ class OpikDataTrace(BaseTraceInstance):
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
execution_metadata = node_execution.metadata or {}
metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update(
{
@@ -202,7 +202,7 @@ class OpikDataTrace(BaseTraceInstance):
}
)
process_data = node_execution.process_data if node_execution.process_data else {}
process_data = node_execution.process_data or {}
provider = None
model = None

View File

@@ -1,3 +1,4 @@
import collections
import json
import logging
import os
@@ -40,7 +41,7 @@ from tasks.ops_trace_task import process_trace_tasks
logger = logging.getLogger(__name__)
class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
case TracingProviderEnum.LANGFUSE:
@@ -121,7 +122,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
raise KeyError(f"Unsupported tracing provider: {provider}")
provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
provider_config_map = OpsTraceProviderConfigMap()
class OpsTraceManager:

View File

@@ -169,13 +169,13 @@ class WeaveDataTrace(BaseTraceInstance):
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs if node_execution.inputs else {}
outputs = node_execution.outputs if node_execution.outputs else {}
inputs = node_execution.inputs or {}
outputs = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
execution_metadata = node_execution.metadata or {}
node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
attributes = {str(k): v for k, v in execution_metadata.items()}
attributes.update(
@@ -190,7 +190,7 @@ class WeaveDataTrace(BaseTraceInstance):
}
)
process_data = node_execution.process_data if node_execution.process_data else {}
process_data = node_execution.process_data or {}
if process_data and process_data.get("model_mode") == "chat":
attributes.update(
{

View File

@@ -641,7 +641,7 @@ class ClickzettaVector(BaseVector):
for doc, embedding in zip(batch_docs, batch_embeddings):
# Optimized: minimal checks for common case, fallback for edge cases
metadata = doc.metadata if doc.metadata else {}
metadata = doc.metadata or {}
if not isinstance(metadata, dict):
metadata = {}

View File

@@ -1,4 +1,3 @@
import copy
import json
import logging
import time
@@ -28,7 +27,7 @@ UGC_INDEX_PREFIX = "ugc_index"
class LindormVectorStoreConfig(BaseModel):
hosts: str
hosts: str | None
username: str | None = None
password: str | None = None
using_ugc: bool | None = False
@@ -46,7 +45,12 @@ class LindormVectorStoreConfig(BaseModel):
return values
def to_opensearch_params(self) -> dict[str, Any]:
params: dict[str, Any] = {"hosts": self.hosts}
params: dict[str, Any] = {
"hosts": self.hosts,
"use_ssl": False,
"pool_maxsize": 128,
"timeout": 30,
}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
@@ -54,18 +58,13 @@ class LindormVectorStoreConfig(BaseModel):
class LindormVectorStore(BaseVector):
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using_ugc: bool, **kwargs):
self._routing = None
self._routing_field = None
self._routing: str | None = None
if using_ugc:
routing_value: str | None = kwargs.get("routing_value")
if routing_value is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
self._routing = routing_value.lower()
self._routing_field = ROUTING_FIELD
ugc_index_name = collection_name
super().__init__(ugc_index_name.lower())
else:
super().__init__(collection_name.lower())
super().__init__(collection_name.lower())
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())
self._using_ugc = using_ugc
@@ -75,7 +74,8 @@ class LindormVectorStore(BaseVector):
return VectorType.LINDORM
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.create_collection(len(embeddings[0]), **kwargs)
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings)
def refresh(self):
@@ -120,7 +120,7 @@ class LindormVectorStore(BaseVector):
for i in range(start_idx, end_idx):
action_header = {
"index": {
"_index": self.collection_name.lower(),
"_index": self.collection_name,
"_id": uuids[i],
}
}
@@ -131,14 +131,11 @@ class LindormVectorStore(BaseVector):
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
if self._routing_field is not None:
action_values[self._routing_field] = self._routing
action_values[ROUTING_FIELD] = self._routing
actions.append(action_header)
actions.append(action_values)
# logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
try:
_bulk_with_retry(actions)
# logger.info(f"Successfully processed batch {batch_num + 1}")
@@ -155,7 +152,7 @@ class LindormVectorStore(BaseVector):
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
response = self._client.search(index=self._collection_name, body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
@@ -216,7 +213,7 @@ class LindormVectorStore(BaseVector):
def delete(self):
if self._using_ugc:
routing_filter_query = {
"query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
"query": {"bool": {"must": [{"term": {f"{ROUTING_FIELD}.keyword": self._routing}}]}}
}
self._client.delete_by_query(self._collection_name, body=routing_filter_query)
self.refresh()
@@ -229,7 +226,7 @@ class LindormVectorStore(BaseVector):
def text_exists(self, id: str) -> bool:
try:
params = {}
params: dict[str, Any] = {}
if self._using_ugc:
params["routing"] = self._routing
self._client.get(index=self._collection_name, id=id, params=params)
@@ -244,20 +241,37 @@ class LindormVectorStore(BaseVector):
if not all(isinstance(x, float) for x in query_vector):
raise ValueError("All elements in query_vector should be floats")
top_k = kwargs.get("top_k", 3)
document_ids_filter = kwargs.get("document_ids_filter")
filters = []
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
if self._using_ugc:
filters.append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
top_k = kwargs.get("top_k", 5)
search_query: dict[str, Any] = {
"size": top_k,
"_source": True,
"query": {"knn": {Field.VECTOR.value: {"vector": query_vector, "k": top_k}}},
}
final_ext: dict[str, Any] = {"lvector": {}}
if filters is not None and len(filters) > 0:
# when using filter, transform filter from List[Dict] to Dict as valid format
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][Field.VECTOR.value]["filter"] = filter_dict # filter should be Dict
final_ext["lvector"]["filter_type"] = "pre_filter"
if final_ext != {"lvector": {}}:
search_query["ext"] = final_ext
try:
params = {"timeout": self._client_config.request_timeout}
if self._using_ugc:
params["routing"] = self._routing # type: ignore
response = self._client.search(index=self._collection_name, body=query, params=params)
response = self._client.search(index=self._collection_name, body=search_query, params=params)
except Exception:
logger.exception("Error executing vector search, query: %s", query)
logger.exception("Error executing vector search, query: %s", search_query)
raise
docs_and_scores = []
@@ -283,283 +297,85 @@ class LindormVectorStore(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
must = kwargs.get("must")
must_not = kwargs.get("must_not")
should = kwargs.get("should")
minimum_should_match = kwargs.get("minimum_should_match", 0)
top_k = kwargs.get("top_k", 3)
filters = kwargs.get("filter", [])
full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}}
filters = []
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
routing = self._routing
full_text_query = default_text_search_query(
query_text=query,
k=top_k,
text_field=Field.CONTENT_KEY.value,
must=must,
must_not=must_not,
should=should,
minimum_should_match=minimum_should_match,
filters=filters,
routing=routing,
routing_field=self._routing_field,
)
params = {"timeout": self._client_config.request_timeout}
response = self._client.search(index=self._collection_name, body=full_text_query, params=params)
if self._using_ugc:
filters.append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
if filters:
full_text_query["query"]["bool"]["filter"] = filters
try:
params: dict[str, Any] = {"timeout": self._client_config.request_timeout}
if self._using_ugc:
params["routing"] = self._routing
response = self._client.search(index=self._collection_name, body=full_text_query, params=params)
except Exception:
logger.exception("Error executing vector search, query: %s", full_text_query)
raise
docs = []
for hit in response["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
)
)
metadata = hit["_source"].get(Field.METADATA_KEY.value)
vector = hit["_source"].get(Field.VECTOR.value)
page_content = hit["_source"].get(Field.CONTENT_KEY.value)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
return docs
def create_collection(self, dimension: int, **kwargs):
def create_collection(
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
):
if not embeddings:
raise ValueError(f"Embeddings list cannot be empty for collection create '{self._collection_name}'")
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info("Collection %s already exists.", self._collection_name)
return
if self._client.indices.exists(index=self._collection_name):
logger.info("%s already exists.", self._collection_name.lower())
redis_client.set(collection_exist_cache_key, 1, ex=3600)
return
if len(self.kwargs) == 0 and len(kwargs) != 0:
self.kwargs = copy.deepcopy(kwargs)
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
shards = kwargs.pop("shards", 4)
engine = kwargs.pop("engine", "lvector")
method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
data_type = kwargs.pop("data_type", "float")
hnsw_m = kwargs.pop("hnsw_m", 24)
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
nlist = kwargs.pop("nlist", 1000)
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", nlist >= 5000)
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
mapping = default_text_mapping(
dimension,
method_name,
space_type=space_type,
shards=shards,
engine=engine,
data_type=data_type,
vector_field=vector_field,
hnsw_m=hnsw_m,
hnsw_ef_construction=hnsw_ef_construction,
nlist=nlist,
ivfpq_m=ivfpq_m,
centroids_use_hnsw=centroids_use_hnsw,
centroids_hnsw_m=centroids_hnsw_m,
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
using_ugc=self._using_ugc,
**kwargs,
)
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
# logger.info(f"create index success: {self._collection_name}")
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any):
excludes_from_source = kwargs.get("excludes_from_source", False)
analyzer = kwargs.get("analyzer", "ik_max_word")
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
engine = kwargs["engine"]
shard = kwargs["shards"]
space_type = kwargs.get("space_type")
if space_type is None:
if method_name == "hnsw":
space_type = "l2"
else:
space_type = "cosine"
data_type = kwargs["data_type"]
vector_field = kwargs.get("vector_field", Field.VECTOR.value)
using_ugc = kwargs.get("using_ugc", False)
if method_name == "ivfpq":
ivfpq_m = kwargs["ivfpq_m"]
nlist = kwargs["nlist"]
centroids_use_hnsw = nlist > 10000
centroids_hnsw_m = 24
centroids_hnsw_ef_construct = 500
centroids_hnsw_ef_search = 100
parameters = {
"m": ivfpq_m,
"nlist": nlist,
"centroids_use_hnsw": centroids_use_hnsw,
"centroids_hnsw_m": centroids_hnsw_m,
"centroids_hnsw_ef_construct": centroids_hnsw_ef_construct,
"centroids_hnsw_ef_search": centroids_hnsw_ef_search,
}
elif method_name == "hnsw":
neighbor = kwargs["hnsw_m"]
ef_construction = kwargs["hnsw_ef_construction"]
parameters = {"m": neighbor, "ef_construction": ef_construction}
elif method_name == "flat":
parameters = {}
else:
raise RuntimeError(f"unexpected method_name: {method_name}")
mapping = {
"settings": {"index": {"number_of_shards": shard, "knn": True}},
"mappings": {
"properties": {
vector_field: {
"type": "knn_vector",
"dimension": dimension,
"data_type": data_type,
"method": {
"engine": engine,
"name": method_name,
"space_type": space_type,
"parameters": parameters,
if not self._client.indices.exists(index=self._collection_name):
index_body = {
"settings": {"index": {"knn": True, "knn_routing": self._using_ugc}},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: {
"type": "knn_vector",
"dimension": len(embeddings[0]), # Make sure the dimension is correct here
"method": {
"name": index_params.get("index_type", "hnsw")
if index_params
else dify_config.LINDORM_INDEX_TYPE,
"space_type": index_params.get("space_type", "l2")
if index_params
else dify_config.LINDORM_DISTANCE_TYPE,
"engine": "lvector",
},
},
}
},
},
text_field: {"type": "text", "analyzer": analyzer},
}
},
}
if excludes_from_source:
# e.g. {"excludes": ["vector_field"]}
mapping["mappings"]["_source"] = {"excludes": [vector_field]}
if using_ugc and method_name == "ivfpq":
mapping["settings"]["index"]["knn_routing"] = True
mapping["settings"]["index"]["knn.offline.construction"] = True
elif (using_ugc and method_name == "hnsw") or (using_ugc and method_name == "flat"):
mapping["settings"]["index"]["knn_routing"] = True
return mapping
def default_text_search_query(
query_text: str,
k: int = 4,
text_field: str = Field.CONTENT_KEY.value,
must: list[dict] | None = None,
must_not: list[dict] | None = None,
should: list[dict] | None = None,
minimum_should_match: int = 0,
filters: list[dict] | None = None,
routing: str | None = None,
routing_field: str | None = None,
**kwargs,
):
query_clause: dict[str, Any] = {}
if routing is not None:
query_clause = {
"bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
}
else:
query_clause = {"match": {text_field: query_text}}
# build the simplest search_query when only query_text is specified
if not must and not must_not and not should and not filters:
search_query = {"size": k, "query": query_clause}
return search_query
# build complex search_query when either of must/must_not/should/filter is specified
if must:
if not isinstance(must, list):
raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
if query_clause not in must:
must.append(query_clause)
else:
must = [query_clause]
boolean_query: dict[str, Any] = {"must": must}
if must_not:
if not isinstance(must_not, list):
raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
boolean_query["must_not"] = must_not
if should:
if not isinstance(should, list):
raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
boolean_query["should"] = should
if minimum_should_match != 0:
boolean_query["minimum_should_match"] = minimum_should_match
if filters:
if not isinstance(filters, list):
raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
boolean_query["filter"] = filters
search_query = {"size": k, "query": {"bool": boolean_query}}
return search_query
def default_vector_search_query(
query_vector: list[float],
k: int = 4,
min_score: str = "0.0",
ef_search: str | None = None, # only for hnsw
nprobe: str | None = None, # "2000"
reorder_factor: str | None = None, # "20"
client_refactor: str | None = None, # "true"
vector_field: str = Field.VECTOR.value,
filters: list[dict] | None = None,
filter_type: str | None = None,
**kwargs,
):
if filters is not None:
filter_type = "pre_filter" if filter_type is None else filter_type
if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext: dict[str, Any] = {"lvector": {}}
if min_score != "0.0":
final_ext["lvector"]["min_score"] = min_score
if ef_search:
final_ext["lvector"]["ef_search"] = ef_search
if nprobe:
final_ext["lvector"]["nprobe"] = nprobe
if reorder_factor:
final_ext["lvector"]["reorder_factor"] = reorder_factor
if client_refactor:
final_ext["lvector"]["client_refactor"] = client_refactor
search_query: dict[str, Any] = {
"size": k,
"_source": True, # force return '_source'
"query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
}
if filters is not None and len(filters) > 0:
# when using filter, transform filter from List[Dict] to Dict as valid format
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][vector_field]["filter"] = filter_dict # filter should be Dict
if filter_type:
final_ext["lvector"]["filter_type"] = filter_type
if final_ext != {"lvector": {}}:
search_query["ext"] = final_ext
return search_query
}
logger.info("Creating Lindorm Search index %s", self._collection_name)
self._client.indices.create(index=self._collection_name, body=index_body)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL or "",
hosts=dify_config.LINDORM_URL,
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.USING_UGC_INDEX,
using_ugc=dify_config.LINDORM_USING_UGC,
request_timeout=dify_config.LINDORM_QUERY_TIMEOUT,
)
using_ugc = dify_config.USING_UGC_INDEX
using_ugc = dify_config.LINDORM_USING_UGC
if using_ugc is None:
raise ValueError("USING_UGC_INDEX is not set")
raise ValueError("LINDORM_USING_UGC is not set")
routing_value = None
if dataset.index_struct:
# if an existing record's index_struct_dict doesn't contain using_ugc field,
@@ -571,27 +387,27 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
index_type = dataset.index_struct_dict["index_type"]
distance_type = dataset.index_struct_dict["distance_type"]
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}".lower()
else:
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"].lower()
else:
embedding_vector = embeddings.embed_query("hello word")
dimension = len(embedding_vector)
index_type = dify_config.DEFAULT_INDEX_TYPE
distance_type = dify_config.DEFAULT_DISTANCE_TYPE
class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
index_struct_dict = {
"type": VectorType.LINDORM,
"vector_store": {"class_prefix": class_prefix},
"index_type": index_type,
"index_type": dify_config.LINDORM_INDEX_TYPE,
"dimension": dimension,
"distance_type": distance_type,
"distance_type": dify_config.LINDORM_DISTANCE_TYPE,
"using_ugc": using_ugc,
}
dataset.index_struct = json.dumps(index_struct_dict)
if using_ugc:
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
routing_value = class_prefix
index_type = dify_config.LINDORM_INDEX_TYPE
distance_type = dify_config.LINDORM_DISTANCE_TYPE
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}".lower()
routing_value = class_prefix.lower()
else:
index_name = class_prefix
index_name = class_prefix.lower()
return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value, using_ugc=using_ugc)

View File

@@ -103,7 +103,7 @@ class MatrixoneVector(BaseVector):
self.client = self._get_client(len(embeddings[0]), True)
assert self.client is not None
ids = []
for _, doc in enumerate(documents):
for doc in documents:
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
ids.append(doc_id)

View File

@@ -104,7 +104,7 @@ class OpenSearchVector(BaseVector):
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
if self._client_config.aws_service not in ["aoss"]:
if self._client_config.aws_service != "aoss":
action["_id"] = uuid4().hex
actions.append(action)

View File

@@ -159,7 +159,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
else None
)
db_model.status = domain_model.status
db_model.error = domain_model.error_message if domain_model.error_message else None
db_model.error = domain_model.error_message or None
db_model.total_tokens = domain_model.total_tokens
db_model.total_steps = domain_model.total_steps
db_model.exceptions_count = domain_model.exceptions_count

View File

@@ -320,7 +320,7 @@ class AgentNode(BaseNode):
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size if node_data.memory.window.size else None
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages

View File

@@ -141,9 +141,7 @@ def init_app(app: DifyApp) -> Celery:
imports.append("schedule.queue_monitor_task")
beat_schedule["datasets-queue-monitor"] = {
"task": "schedule.queue_monitor_task.queue_monitor_task",
"schedule": timedelta(
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
),
"schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
}
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
imports.append("schedule.check_upgradable_plugin_task")

View File

@@ -7,6 +7,7 @@ Supports complete lifecycle management for knowledge base files.
import json
import logging
import operator
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import StrEnum, auto
@@ -356,7 +357,7 @@ class FileLifecycleManager:
# Cleanup old versions for each file
for base_filename, versions in file_versions.items():
# Sort by version number
versions.sort(key=lambda x: x[0], reverse=True)
versions.sort(key=operator.itemgetter(0), reverse=True)
# Keep the newest max_versions versions, delete the rest
if len(versions) > max_versions:

View File

@@ -375,13 +375,14 @@ class WorkflowService:
def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None:
"""
Validate that an LLM model configuration can fetch valid credentials.
Validate that an LLM model configuration can fetch valid credentials and has active status.
This method attempts to get the model instance and validates that:
1. The provider exists and is configured
2. The model exists in the provider
3. Credentials can be fetched for the model
4. The credentials pass policy compliance checks
5. The model status is ACTIVE (not NO_CONFIGURE, DISABLED, etc.)
:param tenant_id: The tenant ID
:param provider: The provider name
@@ -391,6 +392,7 @@ class WorkflowService:
try:
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
# Get model instance to validate provider+model combination
model_manager = ModelManager()
@@ -402,6 +404,22 @@ class WorkflowService:
# via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance()
# If it fails, an exception will be raised
# Additionally, check the model status to ensure it's ACTIVE
provider_manager = ProviderManager()
provider_configurations = provider_manager.get_configurations(tenant_id)
models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM)
target_model = None
for model in models:
if model.model == model_name and model.provider.provider == provider:
target_model = model
break
if target_model:
target_model.raise_for_status()
else:
raise ValueError(f"Model {model_name} not found for provider {provider}")
except Exception as e:
raise ValueError(
f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}"

View File

@@ -1,3 +1,4 @@
import operator
import traceback
import typing
@@ -118,7 +119,7 @@ def process_tenant_plugin_autoupgrade_check_task(
current_version = version
latest_version = manifest.latest_version
def fix_only_checker(latest_version, current_version):
def fix_only_checker(latest_version: str, current_version: str):
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
current_version_tuple = tuple(int(val) for val in current_version.split("."))
@@ -130,8 +131,7 @@ def process_tenant_plugin_autoupgrade_check_task(
return False
version_checker = {
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
current_version: latest_version != current_version,
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne,
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
}

View File

@@ -3,6 +3,7 @@
import os
import tempfile
import unittest
from pathlib import Path
import pytest
@@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
# Test download
with tempfile.NamedTemporaryFile() as temp_file:
storage.download(test_filename, temp_file.name)
with open(temp_file.name, "rb") as f:
downloaded_content = f.read()
downloaded_content = Path(temp_file.name).read_bytes()
assert downloaded_content == test_content
# Test scan

View File

@@ -12,6 +12,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances.
import uuid
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
@@ -276,8 +277,7 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage = mock_external_service_dependencies["storage"]
def mock_download(key, file_path):
with open(file_path, "w", encoding="utf-8") as f:
f.write(csv_content)
Path(file_path).write_text(csv_content, encoding="utf-8")
mock_storage.download.side_effect = mock_download
@@ -505,7 +505,7 @@ class TestBatchCreateSegmentToIndexTask:
db.session.commit()
# Test each unavailable document
for i, document in enumerate(test_cases):
for document in test_cases:
job_id = str(uuid.uuid4())
batch_create_segment_to_index_task(
job_id=job_id,
@@ -601,8 +601,7 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage = mock_external_service_dependencies["storage"]
def mock_download(key, file_path):
with open(file_path, "w", encoding="utf-8") as f:
f.write(empty_csv_content)
Path(file_path).write_text(empty_csv_content, encoding="utf-8")
mock_storage.download.side_effect = mock_download
@@ -684,8 +683,7 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage = mock_external_service_dependencies["storage"]
def mock_download(key, file_path):
with open(file_path, "w", encoding="utf-8") as f:
f.write(csv_content)
Path(file_path).write_text(csv_content, encoding="utf-8")
mock_storage.download.side_effect = mock_download

View File

@@ -362,7 +362,7 @@ class TestCleanDatasetTask:
# Create segments for each document
segments = []
for i, document in enumerate(documents):
for document in documents:
segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document)
segments.append(segment)

View File

@@ -290,9 +290,9 @@ class TestDisableSegmentsFromIndexTask:
# Verify the call arguments (checking by attributes rather than object identity)
call_args = mock_processor.clean.call_args
assert call_args[0][0].id == dataset.id # First argument should be the dataset
assert call_args[0][1] == [
segment.index_node_id for segment in segments
] # Second argument should be node IDs
assert sorted(call_args[0][1]) == sorted(
[segment.index_node_id for segment in segments]
) # Compare sorted lists to handle any order while preserving duplicates
assert call_args[1]["with_keywords"] is True
assert call_args[1]["delete_child_chunks"] is False
@@ -719,7 +719,9 @@ class TestDisableSegmentsFromIndexTask:
# Verify the call arguments
call_args = mock_processor.clean.call_args
assert call_args[0][0].id == dataset.id # First argument should be the dataset
assert call_args[0][1] == expected_node_ids # Second argument should be node IDs
assert sorted(call_args[0][1]) == sorted(
expected_node_ids
) # Compare sorted lists to handle any order while preserving duplicates
assert call_args[1]["with_keywords"] is True
assert call_args[1]["delete_child_chunks"] is False

View File

@@ -0,0 +1,554 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from tasks.document_indexing_task import document_indexing_task
class TestDocumentIndexingTask:
"""Integration tests for document_indexing_task using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner,
patch("tasks.document_indexing_task.FeatureService") as mock_feature_service,
):
# Setup mock indexing runner
mock_runner_instance = MagicMock()
mock_indexing_runner.return_value = mock_runner_instance
# Setup mock feature service
mock_features = MagicMock()
mock_features.billing.enabled = False
mock_feature_service.get_features.return_value = mock_features
yield {
"indexing_runner": mock_indexing_runner,
"indexing_runner_instance": mock_runner_instance,
"feature_service": mock_feature_service,
"features": mock_features,
}
def _create_test_dataset_and_documents(
self, db_session_with_containers, mock_external_service_dependencies, document_count=3
):
"""
Helper method to create a test dataset and documents for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
document_count: Number of documents to create
Returns:
tuple: (dataset, documents) - Created dataset and document instances
"""
fake = Faker()
# Create account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
# Create dataset
dataset = Dataset(
id=fake.uuid4(),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
# Create documents
documents = []
for i in range(document_count):
document = Document(
id=fake.uuid4(),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
documents.append(document)
db.session.commit()
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
return dataset, documents
def _create_test_dataset_with_billing_features(
self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
):
"""
Helper method to create a test dataset with billing features configured.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
billing_enabled: Whether billing is enabled
Returns:
tuple: (dataset, documents) - Created dataset and document instances
"""
fake = Faker()
# Create account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER.value,
current=True,
)
db.session.add(join)
db.session.commit()
# Create dataset
dataset = Dataset(
id=fake.uuid4(),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
# Create documents
documents = []
for i in range(3):
document = Document(
id=fake.uuid4(),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
documents.append(document)
db.session.commit()
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
if billing_enabled:
mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"
mock_external_service_dependencies["features"].vector_space.limit = 100
mock_external_service_dependencies["features"].vector_space.size = 50
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
return dataset, documents
def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful document indexing with multiple documents.
This test verifies:
- Proper dataset retrieval from database
- Correct document processing and status updates
- IndexingRunner integration
- Database state updates
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=3
)
document_ids = [doc.id for doc in documents]
# Act: Execute the task
document_indexing_task(dataset.id, document_ids)
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated to parsing status
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with correct documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
assert call_args is not None
processed_documents = call_args[0][0] # First argument should be documents list
assert len(processed_documents) == 3
def test_document_indexing_task_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling of non-existent dataset.
This test verifies:
- Proper error handling for missing datasets
- Early return without processing
- Database session cleanup
- No unnecessary indexing runner calls
"""
# Arrange: Use non-existent dataset ID
fake = Faker()
non_existent_dataset_id = fake.uuid4()
document_ids = [fake.uuid4() for _ in range(3)]
# Act: Execute the task with non-existent dataset
document_indexing_task(non_existent_dataset_id, document_ids)
# Assert: Verify no processing occurred
mock_external_service_dependencies["indexing_runner"].assert_not_called()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
def test_document_indexing_task_document_not_found_in_dataset(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling when some documents don't exist in the dataset.
This test verifies:
- Only existing documents are processed
- Non-existent documents are ignored
- Indexing runner receives only valid documents
- Database state updates correctly
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
# Mix existing and non-existent document IDs
fake = Faker()
existing_document_ids = [doc.id for doc in documents]
non_existent_document_ids = [fake.uuid4() for _ in range(2)]
all_document_ids = existing_document_ids + non_existent_document_ids
# Act: Execute the task with mixed document IDs
document_indexing_task(dataset.id, all_document_ids)
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only existing documents were updated
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with only existing documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
assert call_args is not None
processed_documents = call_args[0][0] # First argument should be documents list
assert len(processed_documents) == 2 # Only existing documents
def test_document_indexing_task_indexing_runner_exception(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling of IndexingRunner exceptions.
This test verifies:
- Exceptions from IndexingRunner are properly caught
- Task completes without raising exceptions
- Database session is properly closed
- Error logging occurs
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
document_ids = [doc.id for doc in documents]
# Mock IndexingRunner to raise an exception
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception(
"Indexing runner failed"
)
# Act: Execute the task
document_indexing_task(dataset.id, document_ids)
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
def test_document_indexing_task_mixed_document_states(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test processing documents with mixed initial states.
This test verifies:
- Documents with different initial states are handled correctly
- Only valid documents are processed
- Database state updates are consistent
- IndexingRunner receives correct documents
"""
# Arrange: Create test data
dataset, base_documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
# Create additional documents with different states
fake = Faker()
extra_documents = []
# Document with different indexing status
doc1 = Document(
id=fake.uuid4(),
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=2,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=dataset.created_by,
indexing_status="completed", # Already completed
enabled=True,
)
db.session.add(doc1)
extra_documents.append(doc1)
# Document with disabled status
doc2 = Document(
id=fake.uuid4(),
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=3,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=dataset.created_by,
indexing_status="waiting",
enabled=False, # Disabled
)
db.session.add(doc2)
extra_documents.append(doc2)
db.session.commit()
all_documents = base_documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with mixed document states
document_indexing_task(dataset.id, document_ids)
# Assert: Verify processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify all documents were updated to parsing status
for document in all_documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with all documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
assert call_args is not None
processed_documents = call_args[0][0] # First argument should be documents list
assert len(processed_documents) == 4
def test_document_indexing_task_billing_sandbox_plan_batch_limit(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test billing validation for sandbox plan batch upload limit.
This test verifies:
- Sandbox plan batch upload limit enforcement
- Error handling for batch upload limit exceeded
- Document status updates to error state
- Proper error message recording
"""
# Arrange: Create test data with billing enabled
dataset, documents = self._create_test_dataset_with_billing_features(
db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
)
# Configure sandbox plan with batch limit
mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"
# Create more documents than sandbox plan allows (limit is 1)
fake = Faker()
extra_documents = []
for i in range(2): # Total will be 5 documents (3 existing + 2 new)
document = Document(
id=fake.uuid4(),
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=i + 3,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=dataset.created_by,
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
extra_documents.append(document)
db.session.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
document_indexing_task(dataset.id, document_ids)
# Assert: Verify error handling
for document in all_documents:
db.session.refresh(document)
assert document.indexing_status == "error"
assert document.error is not None
assert "batch upload" in document.error
assert document.stopped_at is not None
# Verify no indexing runner was called
mock_external_service_dependencies["indexing_runner"].assert_not_called()
def test_document_indexing_task_billing_disabled_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful processing when billing is disabled.
This test verifies:
- Processing continues normally when billing is disabled
- No billing validation occurs
- Documents are processed successfully
- IndexingRunner is called correctly
"""
# Arrange: Create test data with billing disabled
dataset, documents = self._create_test_dataset_with_billing_features(
db_session_with_containers, mock_external_service_dependencies, billing_enabled=False
)
document_ids = [doc.id for doc in documents]
# Act: Execute the task with billing disabled
document_indexing_task(dataset.id, document_ids)
# Assert: Verify successful processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated to parsing status
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
def test_document_indexing_task_document_is_paused_error(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test handling of DocumentIsPausedError from IndexingRunner.
This test verifies:
- DocumentIsPausedError is properly caught and handled
- Task completes without raising exceptions
- Appropriate logging occurs
- Database session is properly closed
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
document_ids = [doc.id for doc in documents]
# Mock IndexingRunner to raise DocumentIsPausedError
from core.indexing_runner import DocumentIsPausedError
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = DocumentIsPausedError(
"Document indexing is paused"
)
# Act: Execute the task
document_indexing_task(dataset.id, document_ids)
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
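# A minimal sketch (assumed, not Dify's actual task code) of the error-handling
# contract the tests above pin down: DocumentIsPausedError and generic
# IndexingRunner failures are caught and logged rather than re-raised, and the
# database session is always closed. IndexingRunner/DocumentIsPausedError are
# the real imports used in these tests; the function name and the `session`
# parameter are illustrative only.
import logging

from core.indexing_runner import DocumentIsPausedError, IndexingRunner


def run_indexing_sketch(documents, session) -> None:
    try:
        IndexingRunner().run(documents)
    except DocumentIsPausedError:
        logging.info("Document indexing paused")
    except Exception:
        logging.exception("Document indexing failed")
    finally:
        session.close()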

View File

@@ -15,7 +15,7 @@ class FakeResponse:
self.status_code = status_code
self.headers = headers or {}
self.content = content
self.text = text if text else content.decode("utf-8", errors="ignore")
self.text = text or content.decode("utf-8", errors="ignore")
# ---------------------------
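# A quick equivalence check (assumed usage, not part of the test suite): for a
# falsy `text` (None or ""), `text or fallback` and `text if text else fallback`
# pick the same value, so the FakeResponse rewrite above is behaviour-preserving.
fallback = b"hello".decode("utf-8", errors="ignore")
assert ("" or fallback) == ("" if "" else fallback) == "hello"
assert (None or fallback) == (None if None else fallback) == "hello"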

View File

@@ -1,3 +1,4 @@
from pathlib import Path
from unittest.mock import Mock, create_autospec, patch
import pytest
@@ -146,19 +147,17 @@ class TestMetadataBugCompleteValidation:
# Console API create
console_create_file = "api/controllers/console/datasets/metadata.py"
if os.path.exists(console_create_file):
with open(console_create_file) as f:
content = f.read()
# Should contain nullable=False, not nullable=True
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
content = Path(console_create_file).read_text()
# Should contain nullable=False, not nullable=True
assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
# Service API create
service_create_file = "api/controllers/service_api/dataset/metadata.py"
if os.path.exists(service_create_file):
with open(service_create_file) as f:
content = f.read()
# Should contain nullable=False, not nullable=True
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
assert "nullable=True" not in create_api_section
content = Path(service_create_file).read_text()
# Should contain nullable=False, not nullable=True
create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
assert "nullable=True" not in create_api_section
class TestMetadataValidationSummary:

View File

@@ -1,6 +1,7 @@
from pathlib import Path
import yaml # type: ignore
from dotenv import dotenv_values
from pathlib import Path
BASE_API_AND_DOCKER_CONFIG_SET_DIFF = {
"APP_MAX_EXECUTION_TIME",
@@ -98,23 +99,15 @@ with open(Path("docker") / Path("docker-compose.yaml")) as f:
def test_yaml_config():
# python set == operator is used to compare two sets
DIFF_API_WITH_DOCKER = (
API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
)
DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
if DIFF_API_WITH_DOCKER:
print(
f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}"
)
print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}")
raise Exception("API and Docker config sets are different")
DIFF_API_WITH_DOCKER_COMPOSE = (
API_CONFIG_SET
- DOCKER_COMPOSE_CONFIG_SET
- BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
)
if DIFF_API_WITH_DOCKER_COMPOSE:
print(
f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}"
)
print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}")
raise Exception("API and Docker Compose config sets are different")
print("All tests passed!")

View File

@@ -643,9 +643,10 @@ VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
# Lindorm configuration, only available when VECTOR_STORE is `lindorm`
LINDORM_URL=http://lindorm:30070
LINDORM_USERNAME=lindorm
LINDORM_PASSWORD=lindorm
LINDORM_URL=http://localhost:30070
LINDORM_USERNAME=admin
LINDORM_PASSWORD=admin
LINDORM_USING_UGC=True
LINDORM_QUERY_TIMEOUT=1
# OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase`

View File

@@ -292,9 +292,10 @@ x-shared-env: &shared-api-worker-env
VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http}
VIKINGDB_CONNECTION_TIMEOUT: ${VIKINGDB_CONNECTION_TIMEOUT:-30}
VIKINGDB_SOCKET_TIMEOUT: ${VIKINGDB_SOCKET_TIMEOUT:-30}
LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070}
LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm}
LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm}
LINDORM_URL: ${LINDORM_URL:-http://localhost:30070}
LINDORM_USERNAME: ${LINDORM_USERNAME:-admin}
LINDORM_PASSWORD: ${LINDORM_PASSWORD:-admin}
LINDORM_USING_UGC: ${LINDORM_USING_UGC:-True}
LINDORM_QUERY_TIMEOUT: ${LINDORM_QUERY_TIMEOUT:-1}
OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-oceanbase}
OCEANBASE_VECTOR_PORT: ${OCEANBASE_VECTOR_PORT:-2881}

View File

@@ -51,9 +51,7 @@ def cleanup() -> None:
if sys.stdin.isatty():
log.separator()
log.warning("This action cannot be undone!")
confirmation = input(
"Are you sure you want to remove all config and report files? (yes/no): "
)
confirmation = input("Are you sure you want to remove all config and report files? (yes/no): ")
if confirmation.lower() not in ["yes", "y"]:
log.error("Cleanup cancelled.")

View File

@@ -3,4 +3,4 @@
from .config_helper import config_helper
from .logger_helper import Logger, ProgressLogger
__all__ = ["config_helper", "Logger", "ProgressLogger"]
__all__ = ["Logger", "ProgressLogger", "config_helper"]

View File

@@ -65,9 +65,9 @@ class ConfigHelper:
return None
try:
with open(config_path, "r") as f:
with open(config_path) as f:
return json.load(f)
except (json.JSONDecodeError, IOError) as e:
except (OSError, json.JSONDecodeError) as e:
print(f"❌ Error reading {filename}: {e}")
return None
@@ -101,7 +101,7 @@ class ConfigHelper:
with open(config_path, "w") as f:
json.dump(data, f, indent=2)
return True
except IOError as e:
except OSError as e:
print(f"❌ Error writing {filename}: {e}")
return False
@@ -133,7 +133,7 @@ class ConfigHelper:
try:
config_path.unlink()
return True
except IOError as e:
except OSError as e:
print(f"❌ Error deleting {filename}: {e}")
return False
@@ -148,9 +148,9 @@ class ConfigHelper:
return None
try:
with open(state_path, "r") as f:
with open(state_path) as f:
return json.load(f)
except (json.JSONDecodeError, IOError) as e:
except (OSError, json.JSONDecodeError) as e:
print(f"❌ Error reading {self.state_file}: {e}")
return None
@@ -170,7 +170,7 @@ class ConfigHelper:
with open(state_path, "w") as f:
json.dump(data, f, indent=2)
return True
except IOError as e:
except OSError as e:
print(f"❌ Error writing {self.state_file}: {e}")
return False
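# Two facts behind the recurring edits in this file (standalone checks, not
# ConfigHelper code): IOError has been an alias of OSError since Python 3.3,
# and open() defaults to text-read mode, so catching OSError and dropping the
# explicit "r" are both behaviour-preserving.
assert IOError is OSError
with open(__file__) as f:  # no mode argument -> defaults to "r"
    assert f.mode == "r"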

View File

@@ -159,9 +159,7 @@ class ProgressLogger:
if self.logger.use_colors:
progress_bar = self._create_progress_bar()
print(
f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}"
)
print(f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}")
self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)")
else:
print(f"\n[Step {self.current_step}/{self.total_steps}]")

View File

@@ -6,8 +6,7 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
from common import config_helper
from common import Logger
from common import Logger, config_helper
def configure_openai_plugin() -> None:
@@ -72,29 +71,19 @@ def configure_openai_plugin() -> None:
if response.status_code == 200:
log.success("OpenAI plugin configured successfully!")
log.key_value(
"API Base", config_payload["credentials"]["openai_api_base"]
)
log.key_value(
"API Key", config_payload["credentials"]["openai_api_key"]
)
log.key_value("API Base", config_payload["credentials"]["openai_api_base"])
log.key_value("API Key", config_payload["credentials"]["openai_api_key"])
elif response.status_code == 201:
log.success("OpenAI plugin credentials created successfully!")
log.key_value(
"API Base", config_payload["credentials"]["openai_api_base"]
)
log.key_value(
"API Key", config_payload["credentials"]["openai_api_key"]
)
log.key_value("API Base", config_payload["credentials"]["openai_api_base"])
log.key_value("API Key", config_payload["credentials"]["openai_api_key"])
elif response.status_code == 401:
log.error("Configuration failed: Unauthorized")
log.info("Token may have expired. Please run login_admin.py again")
else:
log.error(
f"Configuration failed with status code: {response.status_code}"
)
log.error(f"Configuration failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
except httpx.ConnectError:

View File

@@ -5,10 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import json
from common import config_helper
from common import Logger
import httpx
from common import Logger, config_helper
def create_api_key() -> None:
@@ -90,9 +90,7 @@ def create_api_key() -> None:
}
if config_helper.write_config("api_key_config", api_key_config):
log.info(
f"API key saved to: {config_helper.get_config_path('benchmark_state')}"
)
log.info(f"API key saved to: {config_helper.get_config_path('benchmark_state')}")
else:
log.error("No API token received")
log.debug(f"Response: {json.dumps(response_data, indent=2)}")
@@ -101,9 +99,7 @@ def create_api_key() -> None:
log.error("API key creation failed: Unauthorized")
log.info("Token may have expired. Please run login_admin.py again")
else:
log.error(
f"API key creation failed with status code: {response.status_code}"
)
log.error(f"API key creation failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
except httpx.ConnectError:

View File

@@ -5,9 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import json
from common import config_helper, Logger
import httpx
from common import Logger, config_helper
def import_workflow_app() -> None:
@@ -30,7 +31,7 @@ def import_workflow_app() -> None:
log.error(f"DSL file not found: {dsl_path}")
return
with open(dsl_path, "r") as f:
with open(dsl_path) as f:
yaml_content = f.read()
log.step("Importing workflow app from DSL...")
@@ -86,9 +87,7 @@ def import_workflow_app() -> None:
log.success("Workflow app imported successfully!")
log.key_value("App ID", app_id)
log.key_value("App Mode", response_data.get("app_mode"))
log.key_value(
"DSL Version", response_data.get("imported_dsl_version")
)
log.key_value("DSL Version", response_data.get("imported_dsl_version"))
# Save app_id to config
app_config = {
@@ -99,9 +98,7 @@ def import_workflow_app() -> None:
}
if config_helper.write_config("app_config", app_config):
log.info(
f"App config saved to: {config_helper.get_config_path('benchmark_state')}"
)
log.info(f"App config saved to: {config_helper.get_config_path('benchmark_state')}")
else:
log.error("Import completed but no app_id received")
log.debug(f"Response: {json.dumps(response_data, indent=2)}")

View File

@@ -5,10 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import time
from common import config_helper
from common import Logger
import httpx
from common import Logger, config_helper
def install_openai_plugin() -> None:
@@ -28,9 +28,7 @@ def install_openai_plugin() -> None:
# API endpoint for plugin installation
base_url = "http://localhost:5001"
install_endpoint = (
f"{base_url}/console/api/workspaces/current/plugin/install/marketplace"
)
install_endpoint = f"{base_url}/console/api/workspaces/current/plugin/install/marketplace"
# Plugin identifier
plugin_payload = {
@@ -83,9 +81,7 @@ def install_openai_plugin() -> None:
log.info("Polling for task completion...")
# Poll for task completion
task_endpoint = (
f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}"
)
task_endpoint = f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}"
max_attempts = 30 # 30 attempts with 2 second delay = 60 seconds max
attempt = 0
@@ -131,9 +127,7 @@ def install_openai_plugin() -> None:
plugins = task_info.get("plugins", [])
if plugins:
for plugin in plugins:
log.list_item(
f"{plugin.get('plugin_id')}: {plugin.get('message')}"
)
log.list_item(f"{plugin.get('plugin_id')}: {plugin.get('message')}")
break
# Continue polling if status is "pending" or other
@@ -149,9 +143,7 @@ def install_openai_plugin() -> None:
log.warning("Plugin may already be installed")
log.debug(f"Response: {response.text}")
else:
log.error(
f"Installation failed with status code: {response.status_code}"
)
log.error(f"Installation failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
except httpx.ConnectError:

View File

@@ -5,10 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import json
from common import config_helper
from common import Logger
import httpx
from common import Logger, config_helper
def login_admin() -> None:
@@ -77,16 +77,10 @@ def login_admin() -> None:
# Save token config
if config_helper.write_config("token_config", token_config):
log.info(
f"Token saved to: {config_helper.get_config_path('benchmark_state')}"
)
log.info(f"Token saved to: {config_helper.get_config_path('benchmark_state')}")
# Show truncated token for verification
token_display = (
f"{access_token[:20]}..."
if len(access_token) > 20
else "Token saved"
)
token_display = f"{access_token[:20]}..." if len(access_token) > 20 else "Token saved"
log.key_value("Access token", token_display)
elif response.status_code == 401:

View File

@@ -3,8 +3,10 @@
import json
import time
import uuid
from typing import Any, Iterator
from flask import Flask, request, jsonify, Response
from collections.abc import Iterator
from typing import Any
from flask import Flask, Response, jsonify, request
app = Flask(__name__)
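# Background for the import swap above (illustrative example, not part of the
# mock server): PEP 585 deprecated typing.Iterator in favour of
# collections.abc.Iterator from Python 3.9 onward, so annotating generators
# with the abc version is the forward-compatible spelling.
from collections.abc import Iterator


def stream_numbers(limit: int) -> Iterator[int]:
    yield from range(limit)


assert list(stream_numbers(3)) == [0, 1, 2]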

View File

@@ -5,10 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import json
from common import config_helper
from common import Logger
import httpx
from common import Logger, config_helper
def publish_workflow() -> None:
@@ -79,9 +79,7 @@ def publish_workflow() -> None:
try:
response_data = response.json()
if response_data:
log.debug(
f"Response: {json.dumps(response_data, indent=2)}"
)
log.debug(f"Response: {json.dumps(response_data, indent=2)}")
except json.JSONDecodeError:
# Response might be empty or non-JSON
pass
@@ -93,9 +91,7 @@ def publish_workflow() -> None:
log.error("Workflow publish failed: App not found")
log.info("Make sure the app was imported successfully")
else:
log.error(
f"Workflow publish failed with status code: {response.status_code}"
)
log.error(f"Workflow publish failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
except httpx.ConnectError:

View File

@@ -5,9 +5,10 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
import json
from common import config_helper, Logger
import httpx
from common import Logger, config_helper
def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
@@ -70,9 +71,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non
event = data.get("event")
if event == "workflow_started":
log.progress(
f"Workflow started: {data.get('data', {}).get('id')}"
)
log.progress(f"Workflow started: {data.get('data', {}).get('id')}")
elif event == "node_started":
node_data = data.get("data", {})
log.progress(
@@ -116,9 +115,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non
# Some lines might not be JSON
pass
else:
log.error(
f"Workflow run failed with status code: {response.status_code}"
)
log.error(f"Workflow run failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
else:
# Handle blocking response
@@ -142,9 +139,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> Non
log.info("📤 Final Answer:")
log.info(outputs.get("answer"), indent=2)
else:
log.error(
f"Workflow run failed with status code: {response.status_code}"
)
log.error(f"Workflow run failed with status code: {response.status_code}")
log.debug(f"Response: {response.text}")
except httpx.ConnectError:

View File

@@ -6,7 +6,7 @@ from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import httpx
from common import config_helper, Logger
from common import Logger, config_helper
def setup_admin_account() -> None:
@@ -24,9 +24,7 @@ def setup_admin_account() -> None:
# Save credentials to config file
if config_helper.write_config("admin_config", admin_config):
log.info(
f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}"
)
log.info(f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}")
# API setup endpoint
base_url = "http://localhost:5001"
@@ -56,9 +54,7 @@ def setup_admin_account() -> None:
log.key_value("Username", admin_config["username"])
elif response.status_code == 400:
log.warning(
"Setup may have already been completed or invalid data provided"
)
log.warning("Setup may have already been completed or invalid data provided")
log.debug(f"Response: {response.text}")
else:
log.error(f"Setup failed with status code: {response.status_code}")

View File

@@ -1,9 +1,9 @@
#!/usr/bin/env python3
import socket
import subprocess
import sys
import time
import socket
from pathlib import Path
from common import Logger, ProgressLogger
@@ -93,9 +93,7 @@ def main() -> None:
if retry.lower() in ["yes", "y"]:
return main() # Recursively call main to check again
else:
print(
"❌ Setup cancelled. Please start the required services and try again."
)
print("❌ Setup cancelled. Please start the required services and try again.")
sys.exit(1)
log.success("All required services are running!")

View File

@@ -7,29 +7,28 @@ measuring key metrics like connection rate, event throughput, and time to first
"""
import json
import time
import logging
import os
import random
import statistics
import sys
import threading
import os
import logging
import statistics
from pathlib import Path
import time
from collections import deque
from dataclasses import asdict, dataclass
from datetime import datetime
from dataclasses import dataclass, asdict
from locust import HttpUser, task, between, events, constant
from typing import TypedDict, Literal, TypeAlias
from pathlib import Path
from typing import Literal, TypeAlias, TypedDict
import requests.exceptions
from locust import HttpUser, between, constant, events, task
# Add the stress-test directory to path to import common modules
sys.path.insert(0, str(Path(__file__).parent))
from common.config_helper import ConfigHelper # type: ignore[import-not-found]
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Configuration from environment
@@ -54,6 +53,7 @@ ErrorType: TypeAlias = Literal[
class ErrorCounts(TypedDict):
"""Error count tracking"""
connection_error: int
timeout: int
invalid_json: int
@@ -65,6 +65,7 @@ class ErrorCounts(TypedDict):
class SSEEvent(TypedDict):
"""Server-Sent Event structure"""
data: str
event: str
id: str | None
@@ -72,11 +73,13 @@ class SSEEvent(TypedDict):
class WorkflowInputs(TypedDict):
"""Workflow input structure"""
question: str
class WorkflowRequestData(TypedDict):
"""Workflow request payload"""
inputs: WorkflowInputs
response_mode: Literal["streaming"]
user: str
@@ -84,6 +87,7 @@ class WorkflowRequestData(TypedDict):
class ParsedEventData(TypedDict, total=False):
"""Parsed event data from SSE stream"""
event: str
task_id: str
workflow_run_id: str
@@ -93,6 +97,7 @@ class ParsedEventData(TypedDict, total=False):
class LocustStats(TypedDict):
"""Locust statistics structure"""
total_requests: int
total_failures: int
avg_response_time: float
@@ -102,6 +107,7 @@ class LocustStats(TypedDict):
class ReportData(TypedDict):
"""JSON report structure"""
timestamp: str
duration_seconds: float
metrics: dict[str, object] # Metrics as dict for JSON serialization
@@ -154,7 +160,7 @@ class MetricsTracker:
self.total_connections = 0
self.total_events = 0
self.start_time = time.time()
# Enhanced metrics with memory limits
self.max_samples = 10000 # Prevent unbounded growth
self.ttfe_samples: deque[float] = deque(maxlen=self.max_samples)
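# How the bounded sample buffers above behave (illustrative only): a deque
# created with maxlen silently evicts its oldest entry once full, which is
# what caps the TTFE/duration samples at max_samples without manual trimming.
from collections import deque

window: deque[float] = deque(maxlen=3)
for value in (1.0, 2.0, 3.0, 4.0):
    window.append(value)
assert list(window) == [2.0, 3.0, 4.0]  # 1.0 was evicted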
@@ -233,9 +239,7 @@ class MetricsTracker:
max_ttfe = max(self.ttfe_samples)
p50_ttfe = statistics.median(self.ttfe_samples)
if len(self.ttfe_samples) >= 2:
quantiles = statistics.quantiles(
self.ttfe_samples, n=20, method="inclusive"
)
quantiles = statistics.quantiles(self.ttfe_samples, n=20, method="inclusive")
p95_ttfe = quantiles[18] # 19th of 19 quantiles = 95th percentile
else:
p95_ttfe = max_ttfe
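# Why index 18 is the 95th percentile (illustrative only): with n=20,
# statistics.quantiles returns 19 cut points at 5%, 10%, ..., 95%, so the last
# element is the P95 estimate.
import statistics

cuts = statistics.quantiles(range(1, 101), n=20, method="inclusive")
assert len(cuts) == 19
assert round(cuts[18], 2) == 95.05  # ~95th percentile of 1..100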
@@ -255,9 +259,7 @@ class MetricsTracker:
if durations
else 0
)
events_per_stream_avg = (
statistics.mean(events_per_stream) if events_per_stream else 0
)
events_per_stream_avg = statistics.mean(events_per_stream) if events_per_stream else 0
# Calculate inter-event latency statistics
all_inter_event_times = []
@@ -268,32 +270,20 @@ class MetricsTracker:
inter_event_latency_avg = statistics.mean(all_inter_event_times)
inter_event_latency_p50 = statistics.median(all_inter_event_times)
inter_event_latency_p95 = (
statistics.quantiles(
all_inter_event_times, n=20, method="inclusive"
)[18]
statistics.quantiles(all_inter_event_times, n=20, method="inclusive")[18]
if len(all_inter_event_times) >= 2
else max(all_inter_event_times)
)
else:
inter_event_latency_avg = inter_event_latency_p50 = (
inter_event_latency_p95
) = 0
inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0
else:
stream_duration_avg = stream_duration_p50 = stream_duration_p95 = (
events_per_stream_avg
) = 0
inter_event_latency_avg = inter_event_latency_p50 = (
inter_event_latency_p95
) = 0
stream_duration_avg = stream_duration_p50 = stream_duration_p95 = events_per_stream_avg = 0
inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0
# Also calculate overall average rates
total_elapsed = current_time - self.start_time
overall_conn_rate = (
self.total_connections / total_elapsed if total_elapsed > 0 else 0
)
overall_event_rate = (
self.total_events / total_elapsed if total_elapsed > 0 else 0
)
overall_conn_rate = self.total_connections / total_elapsed if total_elapsed > 0 else 0
overall_event_rate = self.total_events / total_elapsed if total_elapsed > 0 else 0
return MetricsSnapshot(
active_connections=self.active_connections,
@@ -389,7 +379,7 @@ class DifyWorkflowUser(HttpUser):
# Load questions from file or use defaults
if QUESTIONS_FILE and os.path.exists(QUESTIONS_FILE):
with open(QUESTIONS_FILE, "r") as f:
with open(QUESTIONS_FILE) as f:
self.questions = [line.strip() for line in f if line.strip()]
else:
self.questions = [
@@ -451,18 +441,13 @@ class DifyWorkflowUser(HttpUser):
try:
# Validate response
if response.status_code >= 400:
error_type: ErrorType = (
"http_4xx" if response.status_code < 500 else "http_5xx"
)
error_type: ErrorType = "http_4xx" if response.status_code < 500 else "http_5xx"
metrics.record_error(error_type)
response.failure(f"HTTP {response.status_code}")
return
content_type = response.headers.get("Content-Type", "")
if (
"text/event-stream" not in content_type
and "application/json" not in content_type
):
if "text/event-stream" not in content_type and "application/json" not in content_type:
logger.error(f"Expected text/event-stream, got: {content_type}")
metrics.record_error("invalid_response")
response.failure(f"Invalid content type: {content_type}")
@@ -473,10 +458,13 @@ class DifyWorkflowUser(HttpUser):
for line in response.iter_lines(decode_unicode=True):
# Check if runner is stopping
if getattr(self.environment.runner, 'state', '') in ('stopping', 'stopped'):
if getattr(self.environment.runner, "state", "") in (
"stopping",
"stopped",
):
logger.debug("Runner stopping, breaking streaming loop")
break
if line is not None:
bytes_received += len(line.encode("utf-8"))
@@ -489,9 +477,7 @@ class DifyWorkflowUser(HttpUser):
# Track inter-event timing
if last_event_time:
inter_event_times.append(
(current_time - last_event_time) * 1000
)
inter_event_times.append((current_time - last_event_time) * 1000)
last_event_time = current_time
if first_event_time is None:
@@ -512,15 +498,11 @@ class DifyWorkflowUser(HttpUser):
parsed_event: ParsedEventData = json.loads(event_data)
# Check for terminal events
if parsed_event.get("event") in TERMINAL_EVENTS:
logger.debug(
f"Received terminal event: {parsed_event.get('event')}"
)
logger.debug(f"Received terminal event: {parsed_event.get('event')}")
request_success = True
break
except json.JSONDecodeError as e:
logger.debug(
f"JSON decode error: {e} for data: {event_data[:100]}"
)
logger.debug(f"JSON decode error: {e} for data: {event_data[:100]}")
metrics.record_error("invalid_json")
except Exception as e:
@@ -583,16 +565,18 @@ def on_test_start(environment: object, **kwargs: object) -> None:
# Periodic stats reporting
def report_stats() -> None:
if not hasattr(environment, 'runner'):
if not hasattr(environment, "runner"):
return
runner = environment.runner
while hasattr(runner, 'state') and runner.state not in ["stopped", "stopping"]:
while hasattr(runner, "state") and runner.state not in ["stopped", "stopping"]:
time.sleep(5) # Report every 5 seconds
if hasattr(runner, 'state') and runner.state == "running":
if hasattr(runner, "state") and runner.state == "running":
stats = metrics.get_stats()
# Only log on master node in distributed mode
is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, 'runner') else True
is_master = (
not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True
)
if is_master:
# Clear previous lines and show updated stats
logger.info("\n" + "=" * 80)
@@ -623,15 +607,15 @@ def on_test_start(environment: object, **kwargs: object) -> None:
logger.info(
f"{'(TTFE in ms)':<25} {stats.ttfe_avg:>15.1f} {stats.ttfe_p50:>10.1f} {stats.ttfe_p95:>10.1f} {stats.ttfe_min:>10.1f} {stats.ttfe_max:>10.1f}"
)
logger.info(f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)")
logger.info(
f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)"
)
logger.info(f"{'Total Samples':<25} {stats.ttfe_total_samples:>15,d}")
# Inter-event latency
if stats.inter_event_latency_avg > 0:
logger.info("-" * 80)
logger.info(
f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}"
)
logger.info(f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}")
logger.info(
f"{'(ms between events)':<25} {stats.inter_event_latency_avg:>15.1f} {stats.inter_event_latency_p50:>10.1f} {stats.inter_event_latency_p95:>10.1f}"
)
@@ -647,9 +631,9 @@ def on_test_start(environment: object, **kwargs: object) -> None:
logger.info("=" * 80)
# Show Locust stats summary
if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'):
if hasattr(environment, "stats") and hasattr(environment.stats, "total"):
total = environment.stats.total
if hasattr(total, 'num_requests') and total.num_requests > 0:
if hasattr(total, "num_requests") and total.num_requests > 0:
logger.info(
f"{'LOCUST STATS':<25} {'Requests':>12} {'Fails':>8} {'Avg (ms)':>12} {'Min':>8} {'Max':>8}"
)
@@ -687,21 +671,15 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
logger.info("")
logger.info("EVENTS")
logger.info(f" {'Total Events Received:':<30} {stats.total_events:>10,d}")
logger.info(
f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s"
)
logger.info(
f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s"
)
logger.info(f" {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s")
logger.info(f" {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s")
logger.info("")
logger.info("STREAM METRICS")
logger.info(f" {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms")
logger.info(f" {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms")
logger.info(f" {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms")
logger.info(
f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}"
)
logger.info(f" {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}")
logger.info("")
logger.info("INTER-EVENT LATENCY")
@@ -716,7 +694,9 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
logger.info(f" {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms")
logger.info(f" {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms")
logger.info(f" {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms")
logger.info(f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})")
logger.info(
f" {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})"
)
logger.info(f" {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}")
# Error summary
@@ -730,7 +710,7 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
logger.info("=" * 80 + "\n")
# Export machine-readable report (only on master node)
is_master = not getattr(environment.runner, 'worker_id', None) if hasattr(environment, 'runner') else True
is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True
if is_master:
export_json_report(stats, test_duration, environment)
@@ -746,9 +726,9 @@ def export_json_report(stats: MetricsSnapshot, duration: float, environment: obj
# Access environment.stats.total attributes safely
locust_stats: LocustStats | None = None
if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'):
if hasattr(environment, "stats") and hasattr(environment.stats, "total"):
total = environment.stats.total
if hasattr(total, 'num_requests') and total.num_requests > 0:
if hasattr(total, "num_requests") and total.num_requests > 0:
locust_stats = LocustStats(
total_requests=total.num_requests,
total_failures=total.num_failures,

View File

@@ -1,7 +1,15 @@
from dify_client.client import (
ChatClient,
CompletionClient,
WorkflowClient,
KnowledgeBaseClient,
DifyClient,
KnowledgeBaseClient,
WorkflowClient,
)
__all__ = [
"ChatClient",
"CompletionClient",
"DifyClient",
"KnowledgeBaseClient",
"WorkflowClient",
]

View File

@@ -8,16 +8,16 @@ class DifyClient:
self.api_key = api_key
self.base_url = base_url
def _send_request(self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False):
def _send_request(
self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False
):
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
url = f"{self.base_url}{endpoint}"
response = requests.request(
method, url, json=json, params=params, headers=headers, stream=stream
)
response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream)
return response
@@ -25,9 +25,7 @@ class DifyClient:
headers = {"Authorization": f"Bearer {self.api_key}"}
url = f"{self.base_url}{endpoint}"
response = requests.request(
method, url, data=data, headers=headers, files=files
)
response = requests.request(method, url, data=data, headers=headers, files=files)
return response
@@ -41,9 +39,7 @@ class DifyClient:
def file_upload(self, user: str, files: dict):
data = {"user": user}
return self._send_request_with_files(
"POST", "/files/upload", data=data, files=files
)
return self._send_request_with_files("POST", "/files/upload", data=data, files=files)
def text_to_audio(self, text: str, user: str, streaming: bool = False):
data = {"text": text, "user": user, "streaming": streaming}
@@ -55,7 +51,9 @@ class DifyClient:
class CompletionClient(DifyClient):
def create_completion_message(self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None):
def create_completion_message(
self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None
):
data = {
"inputs": inputs,
"response_mode": response_mode,
@@ -99,9 +97,7 @@ class ChatClient(DifyClient):
def get_suggested(self, message_id: str, user: str):
params = {"user": user}
return self._send_request(
"GET", f"/messages/{message_id}/suggested", params=params
)
return self._send_request("GET", f"/messages/{message_id}/suggested", params=params)
def stop_message(self, task_id: str, user: str):
data = {"user": user}
@@ -112,10 +108,9 @@ class ChatClient(DifyClient):
user: str,
last_id: str | None = None,
limit: int | None = None,
pinned: bool | None = None
pinned: bool | None = None,
):
params = {"user": user, "last_id": last_id,
"limit": limit, "pinned": pinned}
params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
return self._send_request("GET", "/conversations", params=params)
def get_conversation_messages(
@@ -123,7 +118,7 @@ class ChatClient(DifyClient):
user: str,
conversation_id: str | None = None,
first_id: str | None = None,
limit: int | None = None
limit: int | None = None,
):
params = {"user": user}
@@ -136,13 +131,9 @@ class ChatClient(DifyClient):
return self._send_request("GET", "/messages", params=params)
def rename_conversation(
self, conversation_id: str, name: str, auto_generate: bool, user: str
):
def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str):
data = {"name": name, "auto_generate": auto_generate, "user": user}
return self._send_request(
"POST", f"/conversations/{conversation_id}/name", data
)
return self._send_request("POST", f"/conversations/{conversation_id}/name", data)
def delete_conversation(self, conversation_id: str, user: str):
data = {"user": user}
@@ -155,9 +146,7 @@ class ChatClient(DifyClient):
class WorkflowClient(DifyClient):
def run(
self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"
):
def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"):
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
return self._send_request("POST", "/workflows/run", data)
@@ -197,13 +186,9 @@ class KnowledgeBaseClient(DifyClient):
return self._send_request("POST", "/datasets", {"name": name}, **kwargs)
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
return self._send_request(
"GET", f"/datasets?page={page}&limit={page_size}", **kwargs
)
return self._send_request("GET", f"/datasets?page={page}&limit={page_size}", **kwargs)
def create_document_by_text(
self, name, text, extra_params: dict | None = None, **kwargs
):
def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs):
"""
Create a document by text.
@@ -272,9 +257,7 @@ class KnowledgeBaseClient(DifyClient):
data = {"name": name, "text": text}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = (
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
)
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
return self._send_request("POST", url, json=data, **kwargs)
def create_document_by_file(
@@ -315,13 +298,9 @@ class KnowledgeBaseClient(DifyClient):
if original_document_id is not None:
data["original_document_id"] = original_document_id
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
return self._send_request_with_files(
"POST", url, {"data": json.dumps(data)}, files
)
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
def update_document_by_file(
self, document_id: str, file_path: str, extra_params: dict | None = None
):
def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None):
"""
Update a document by file.
@@ -351,12 +330,8 @@ class KnowledgeBaseClient(DifyClient):
data = {}
if extra_params is not None and isinstance(extra_params, dict):
data.update(extra_params)
url = (
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
)
return self._send_request_with_files(
"POST", url, {"data": json.dumps(data)}, files
)
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
def batch_indexing_status(self, batch_id: str, **kwargs):
"""

View File

@@ -1,6 +1,6 @@
from setuptools import setup
with open("README.md", "r", encoding="utf-8") as fh:
with open("README.md", encoding="utf-8") as fh:
long_description = fh.read()
setup(

View File

@@ -18,9 +18,7 @@ FILE_PATH_BASE = os.path.dirname(__file__)
class TestKnowledgeBaseClient(unittest.TestCase):
def setUp(self):
self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL)
self.README_FILE_PATH = os.path.abspath(
os.path.join(FILE_PATH_BASE, "../README.md")
)
self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md"))
self.dataset_id = None
self.document_id = None
self.segment_id = None
@@ -28,9 +26,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _get_dataset_kb_client(self):
self.assertIsNotNone(self.dataset_id)
return KnowledgeBaseClient(
API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id
)
return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id)
def test_001_create_dataset(self):
response = self.knowledge_base_client.create_dataset(name="test_dataset")
@@ -76,9 +72,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_004_update_document_by_text(self):
client = self._get_dataset_kb_client()
self.assertIsNotNone(self.document_id)
response = client.update_document_by_text(
self.document_id, "test_document_updated", "test_text_updated"
)
response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated")
data = response.json()
self.assertIn("document", data)
self.assertIn("batch", data)
@@ -93,9 +87,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_006_update_document_by_file(self):
client = self._get_dataset_kb_client()
self.assertIsNotNone(self.document_id)
response = client.update_document_by_file(
self.document_id, self.README_FILE_PATH
)
response = client.update_document_by_file(self.document_id, self.README_FILE_PATH)
data = response.json()
self.assertIn("document", data)
self.assertIn("batch", data)
@@ -125,9 +117,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
def _test_010_add_segments(self):
client = self._get_dataset_kb_client()
response = client.add_segments(
self.document_id, [{"content": "test text segment 1"}]
)
response = client.add_segments(self.document_id, [{"content": "test text segment 1"}])
data = response.json()
self.assertIn("data", data)
self.assertGreater(len(data["data"]), 0)
@@ -174,18 +164,12 @@ class TestChatClient(unittest.TestCase):
self.chat_client = ChatClient(API_KEY)
def test_create_chat_message(self):
response = self.chat_client.create_chat_message(
{}, "Hello, World!", "test_user"
)
response = self.chat_client.create_chat_message({}, "Hello, World!", "test_user")
self.assertIn("answer", response.text)
def test_create_chat_message_with_vision_model_by_remote_url(self):
files = [
{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}
]
response = self.chat_client.create_chat_message(
{}, "Describe the picture.", "test_user", files=files
)
files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}]
response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
self.assertIn("answer", response.text)
def test_create_chat_message_with_vision_model_by_local_file(self):
@@ -196,15 +180,11 @@ class TestChatClient(unittest.TestCase):
"upload_file_id": "your_file_id",
}
]
response = self.chat_client.create_chat_message(
{}, "Describe the picture.", "test_user", files=files
)
response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
self.assertIn("answer", response.text)
def test_get_conversation_messages(self):
response = self.chat_client.get_conversation_messages(
"test_user", "your_conversation_id"
)
response = self.chat_client.get_conversation_messages("test_user", "your_conversation_id")
self.assertIn("answer", response.text)
def test_get_conversations(self):
@@ -223,9 +203,7 @@ class TestCompletionClient(unittest.TestCase):
self.assertIn("answer", response.text)
def test_create_completion_message_with_vision_model_by_remote_url(self):
files = [
{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}
]
files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}]
response = self.completion_client.create_completion_message(
{"query": "Describe the picture."}, "blocking", "test_user", files
)
@@ -250,9 +228,7 @@ class TestDifyClient(unittest.TestCase):
self.dify_client = DifyClient(API_KEY)
def test_message_feedback(self):
response = self.dify_client.message_feedback(
"your_message_id", "like", "test_user"
)
response = self.dify_client.message_feedback("your_message_id", "like", "test_user")
self.assertIn("success", response.text)
def test_get_application_parameters(self):

View File

@@ -26,9 +26,16 @@ export const useTags = (translateFromOut?: TFunction) => {
return acc
}, {} as Record<string, Tag>)
const getTagLabel = (name: string) => {
if (!tagsMap[name])
return name
return tagsMap[name].label
}
return {
tags,
tagsMap,
getTagLabel,
}
}

View File

@@ -29,7 +29,7 @@ const CardWrapper = ({
setFalse: hideInstallFromMarketplace,
}] = useBoolean(false)
const { locale: localeFromLocale } = useI18N()
const { tagsMap } = useTags(t)
const { getTagLabel } = useTags(t)
if (showInstallButton) {
return (
@@ -43,7 +43,7 @@ const CardWrapper = ({
footer={
<CardMoreInfo
downloadCount={plugin.install_count}
tags={plugin.tags.map(tag => tagsMap[tag.name].label)}
tags={plugin.tags.map(tag => getTagLabel(tag.name))}
/>
}
/>
@@ -92,7 +92,7 @@ const CardWrapper = ({
footer={
<CardMoreInfo
downloadCount={plugin.install_count}
tags={plugin.tags.map(tag => tagsMap[tag.name].label)}
tags={plugin.tags.map(tag => getTagLabel(tag.name))}
/>
}
/>

View File

@@ -7,6 +7,7 @@ import type {
PluginsSearchParams,
} from '@/app/components/plugins/marketplace/types'
import {
APP_VERSION,
MARKETPLACE_API_PREFIX,
} from '@/config'
import { getMarketplaceUrl } from '@/utils/var'
@@ -49,11 +50,15 @@ export const getMarketplacePluginsByCollectionId = async (collectionId: string,
try {
const url = `${MARKETPLACE_API_PREFIX}/collections/${collectionId}/plugins`
const headers = new Headers({
'X-Dify-Version': APP_VERSION,
})
const marketplaceCollectionPluginsData = await globalThis.fetch(
url,
{
cache: 'no-store',
method: 'POST',
headers,
body: JSON.stringify({
category: query?.category,
exclude: query?.exclude,
@@ -83,7 +88,10 @@ export const getMarketplaceCollectionsAndPlugins = async (query?: CollectionsAnd
marketplaceUrl += `&condition=${query.condition}`
if (query?.type)
marketplaceUrl += `&type=${query.type}`
const marketplaceCollectionsData = await globalThis.fetch(marketplaceUrl, { cache: 'no-store' })
const headers = new Headers({
'X-Dify-Version': APP_VERSION,
})
const marketplaceCollectionsData = await globalThis.fetch(marketplaceUrl, { headers, cache: 'no-store' })
const marketplaceCollectionsDataJson = await marketplaceCollectionsData.json()
marketplaceCollections = marketplaceCollectionsDataJson.data.collections
await Promise.all(marketplaceCollections.map(async (collection: MarketplaceCollection) => {

View File

@@ -27,7 +27,7 @@ const TagsFilter = ({
const { t } = useTranslation()
const [open, setOpen] = useState(false)
const [searchText, setSearchText] = useState('')
const { tags: options, tagsMap } = useTags()
const { tags: options, getTagLabel } = useTags()
const filteredOptions = options.filter(option => option.name.toLowerCase().includes(searchText.toLowerCase()))
const handleCheck = (id: string) => {
if (value.includes(id))
@@ -59,7 +59,7 @@ const TagsFilter = ({
!selectedTagsLength && t('pluginTags.allTags')
}
{
!!selectedTagsLength && value.map(val => tagsMap[val].label).slice(0, 2).join(',')
!!selectedTagsLength && value.map(val => getTagLabel(val)).slice(0, 2).join(',')
}
{
selectedTagsLength > 2 && (

View File

@@ -47,14 +47,14 @@ const useConfig = (id: string, payload: LoopNodeType) => {
})
const changeErrorResponseMode = useCallback((item: { value: unknown }) => {
const newInputs = produce(inputs, (draft) => {
const newInputs = produce(inputsRef.current, (draft) => {
draft.error_handle_mode = item.value as ErrorHandleMode
})
setInputs(newInputs)
}, [inputs, setInputs])
handleInputsChange(newInputs)
}, [inputs, handleInputsChange])
const handleAddCondition = useCallback<HandleAddCondition>((valueSelector, varItem) => {
const newInputs = produce(inputs, (draft) => {
const newInputs = produce(inputsRef.current, (draft) => {
if (!draft.break_conditions)
draft.break_conditions = []
@@ -66,34 +66,34 @@ const useConfig = (id: string, payload: LoopNodeType) => {
value: varItem.type === VarType.boolean ? 'false' : '',
})
})
setInputs(newInputs)
}, [getIsVarFileAttribute, inputs, setInputs])
handleInputsChange(newInputs)
}, [getIsVarFileAttribute, handleInputsChange])
const handleRemoveCondition = useCallback<HandleRemoveCondition>((conditionId) => {
const newInputs = produce(inputs, (draft) => {
const newInputs = produce(inputsRef.current, (draft) => {
draft.break_conditions = draft.break_conditions?.filter(item => item.id !== conditionId)
})
setInputs(newInputs)
}, [inputs, setInputs])
handleInputsChange(newInputs)
}, [handleInputsChange])
const handleUpdateCondition = useCallback<HandleUpdateCondition>((conditionId, newCondition) => {
const newInputs = produce(inputs, (draft) => {
const newInputs = produce(inputsRef.current, (draft) => {
const targetCondition = draft.break_conditions?.find(item => item.id === conditionId)
if (targetCondition)
Object.assign(targetCondition, newCondition)
})
setInputs(newInputs)
}, [inputs, setInputs])
handleInputsChange(newInputs)
}, [handleInputsChange])
const handleToggleConditionLogicalOperator = useCallback<HandleToggleConditionLogicalOperator>(() => {
const newInputs = produce(inputs, (draft) => {
const newInputs = produce(inputsRef.current, (draft) => {
draft.logical_operator = draft.logical_operator === LogicalOperator.and ? LogicalOperator.or : LogicalOperator.and
})
setInputs(newInputs)
}, [inputs, setInputs])
handleInputsChange(newInputs)
}, [handleInputsChange])
const handleAddSubVariableCondition = useCallback<HandleAddSubVariableCondition>((conditionId: string, key?: string) => {
const newInputs = produce(inputs, (draft) => {
const newInputs = produce(inputsRef.current, (draft) => {
const condition = draft.break_conditions?.find(item => item.id === conditionId)
if (!condition)
return
@@ -119,11 +119,11 @@ const useConfig = (id: string, payload: LoopNodeType) => {
})
}
})
setInputs(newInputs)
}, [inputs, setInputs])
handleInputsChange(newInputs)
}, [handleInputsChange])
const handleRemoveSubVariableCondition = useCallback((conditionId: string, subConditionId: string) => {
const newInputs = produce(inputs, (draft) => {
const newInputs = produce(inputsRef.current, (draft) => {
const condition = draft.break_conditions?.find(item => item.id === conditionId)
if (!condition)
return
@@ -133,11 +133,11 @@ const useConfig = (id: string, payload: LoopNodeType) => {
if (subVarCondition)
subVarCondition.conditions = subVarCondition.conditions.filter(item => item.id !== subConditionId)
})
setInputs(newInputs)
}, [inputs, setInputs])
handleInputsChange(newInputs)
}, [handleInputsChange])
const handleUpdateSubVariableCondition = useCallback<HandleUpdateSubVariableCondition>((conditionId, subConditionId, newSubCondition) => {
const newInputs = produce(inputs, (draft) => {
const newInputs = produce(inputsRef.current, (draft) => {
const targetCondition = draft.break_conditions?.find(item => item.id === conditionId)
if (targetCondition && targetCondition.sub_variable_condition) {
const targetSubCondition = targetCondition.sub_variable_condition.conditions.find(item => item.id === subConditionId)
@@ -145,24 +145,24 @@ const useConfig = (id: string, payload: LoopNodeType) => {
Object.assign(targetSubCondition, newSubCondition)
}
})
setInputs(newInputs)
}, [inputs, setInputs])
handleInputsChange(newInputs)
}, [handleInputsChange])
const handleToggleSubVariableConditionLogicalOperator = useCallback<HandleToggleSubVariableConditionLogicalOperator>((conditionId) => {
const newInputs = produce(inputs, (draft) => {
const newInputs = produce(inputsRef.current, (draft) => {
const targetCondition = draft.break_conditions?.find(item => item.id === conditionId)
if (targetCondition && targetCondition.sub_variable_condition)
targetCondition.sub_variable_condition.logical_operator = targetCondition.sub_variable_condition.logical_operator === LogicalOperator.and ? LogicalOperator.or : LogicalOperator.and
})
setInputs(newInputs)
}, [inputs, setInputs])
handleInputsChange(newInputs)
}, [handleInputsChange])
const handleUpdateLoopCount = useCallback((value: number) => {
const newInputs = produce(inputs, (draft) => {
const newInputs = produce(inputsRef.current, (draft) => {
draft.loop_count = value
})
setInputs(newInputs)
}, [inputs, setInputs])
handleInputsChange(newInputs)
}, [handleInputsChange])
const handleAddLoopVariable = useCallback(() => {
const newInputs = produce(inputsRef.current, (draft) => {

View File

@@ -2,6 +2,7 @@ import { InputVarType } from '@/app/components/workflow/types'
import { AgentStrategy } from '@/types/app'
import { PromptRole } from '@/models/debug'
import { DatasetAttr } from '@/types/feature'
import pkg from '../package.json'
const getBooleanConfig = (envVar: string | undefined, dataAttrKey: DatasetAttr, defaultValue: boolean = true) => {
if (envVar !== undefined && envVar !== '')
@@ -294,3 +295,5 @@ export const ZENDESK_FIELD_IDS = {
WORKSPACE_ID: getStringConfig(process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID, DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID, ''),
PLAN: getStringConfig(process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN, DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN, ''),
}
export const APP_VERSION = pkg.version

View File

@@ -2,7 +2,7 @@ import type { AfterResponseHook, BeforeErrorHook, BeforeRequestHook, Hooks } fro
import ky from 'ky'
import type { IOtherOptions } from './base'
import Toast from '@/app/components/base/toast'
import { API_PREFIX, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config'
import { API_PREFIX, APP_VERSION, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config'
import { getInitialTokenV2, isTokenV1 } from '@/app/components/share/utils'
import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/chat/utils'
@@ -151,6 +151,10 @@ async function base<T>(url: string, options: FetchOptionType = {}, otherOptions:
if (deleteContentType)
(headers as any).delete('Content-Type')
// ! For the Marketplace API: lets the marketplace filter tags added in newer versions
if (isMarketplaceAPI)
(headers as any).set('X-Dify-Version', APP_VERSION)
const client = baseClient.extend({
hooks: {
...baseHooks,

View File

@@ -3,8 +3,8 @@ html[data-theme="dark"] {
rgba(34, 34, 37, 0.9) 0%,
rgba(29, 29, 32, 0.9) 90.48%);
--color-chat-bubble-bg: linear-gradient(180deg,
rgba(200, 206, 218, 0.08) 0%,
rgba(200, 206, 218, 0.02) 100%);
rgb(42, 43, 48) 0%,
rgb(37, 38, 42) 100%);
--color-chat-input-mask: linear-gradient(180deg,
rgba(24, 24, 27, 0.04) 0%,
rgba(24, 24, 27, 0.60) 100%);