Mirror of https://github.com/langgenius/dify.git (synced 2026-02-09 17:54:01 +00:00)

Compare commits: purity...fix/plugin

11 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 3a90cd2d03 |  |
|  | b283b10d3e |  |
|  | ecb22226d6 |  |
|  | 8635aacb46 |  |
|  | bdd85b36a4 |  |
|  | a0c7713494 |  |
|  | abf4955c26 |  |
|  | 74340e3c04 |  |
|  | b98b389baf |  |
|  | 130c01ff6a |  |
|  | a9b0b7a3b4 |  |
2 .github/workflows/autofix.yml vendored
@@ -22,7 +22,7 @@ jobs:
           # Fix lint errors
           uv run ruff check --fix .
           # Format code
-          uv run ruff format .
+          uv run ruff format ..
 
     - name: ast-grep
       run: |
87 AGENTS.md Normal file
@@ -0,0 +1,87 @@
+# AGENTS.md
+
+## Project Overview
+
+Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
+
+The codebase consists of:
+
+- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
+- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
+- **Docker deployment** (`/docker`): Containerized deployment configurations
+
+## Development Commands
+
+### Backend (API)
+
+All Python commands must be prefixed with `uv run --project api`:
+
+```bash
+# Start development servers
+./dev/start-api      # Start API server
+./dev/start-worker   # Start Celery worker
+
+# Run tests
+uv run --project api pytest                           # Run all tests
+uv run --project api pytest tests/unit_tests/         # Unit tests only
+uv run --project api pytest tests/integration_tests/  # Integration tests
+
+# Code quality
+./dev/reformat                            # Run all formatters and linters
+uv run --project api ruff check --fix ./  # Fix linting issues
+uv run --project api ruff format ./       # Format code
+uv run --directory api basedpyright       # Type checking
+```
+
+### Frontend (Web)
+
+```bash
+cd web
+pnpm lint        # Run ESLint
+pnpm eslint-fix  # Fix ESLint issues
+pnpm test        # Run Jest tests
+```
+
+## Testing Guidelines
+
+### Backend Testing
+
+- Use `pytest` for all backend tests
+- Write tests first (TDD approach)
+- Test structure: Arrange-Act-Assert
+
+## Code Style Requirements
+
+### Python
+
+- Use type hints for all functions and class attributes
+- No `Any` types unless absolutely necessary
+- Implement special methods (`__repr__`, `__str__`) appropriately
+
+### TypeScript/JavaScript
+
+- Strict TypeScript configuration
+- ESLint with Prettier integration
+- Avoid `any` type
+
+## Important Notes
+
+- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
+- **Comments**: Only write meaningful comments that explain "why", not "what"
+- **File Creation**: Always prefer editing existing files over creating new ones
+- **Documentation**: Don't create documentation files unless explicitly requested
+- **Code Quality**: Always run `./dev/reformat` before committing backend changes
+
+## Common Development Tasks
+
+### Adding a New API Endpoint
+
+1. Create controller in `/api/controllers/`
+1. Add service logic in `/api/services/`
+1. Update routes in controller's `__init__.py`
+1. Write tests in `/api/tests/`
+
+## Project-Specific Conventions
+
+- All async tasks use Celery with Redis as broker
+- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
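For reference, a minimal sketch of the Arrange-Act-Assert structure that AGENTS.md asks for; the `Counter` class is a hypothetical stand-in for a real unit under test, not code from this repository:

```python
class Counter:
    """Hypothetical unit under test, not part of this diff."""

    def __init__(self) -> None:
        self.value = 0

    def increment(self) -> int:
        self.value += 1
        return self.value


def test_increment_returns_new_value():
    # Arrange: build the object in a known state
    counter = Counter()
    # Act: perform exactly one behavior
    result = counter.increment()
    # Assert: check the observable outcome
    assert result == 1
```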
89 CLAUDE.md
@@ -1,89 +0,0 @@
-# CLAUDE.md
-
-This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
-
-## Project Overview
-
-Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
-
-The codebase consists of:
-
-- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
-- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
-- **Docker deployment** (`/docker`): Containerized deployment configurations
-
-## Development Commands
-
-### Backend (API)
-
-All Python commands must be prefixed with `uv run --project api`:
-
-```bash
-# Start development servers
-./dev/start-api      # Start API server
-./dev/start-worker   # Start Celery worker
-
-# Run tests
-uv run --project api pytest                           # Run all tests
-uv run --project api pytest tests/unit_tests/         # Unit tests only
-uv run --project api pytest tests/integration_tests/  # Integration tests
-
-# Code quality
-./dev/reformat                            # Run all formatters and linters
-uv run --project api ruff check --fix ./  # Fix linting issues
-uv run --project api ruff format ./       # Format code
-uv run --directory api basedpyright       # Type checking
-```
-
-### Frontend (Web)
-
-```bash
-cd web
-pnpm lint        # Run ESLint
-pnpm eslint-fix  # Fix ESLint issues
-pnpm test        # Run Jest tests
-```
-
-## Testing Guidelines
-
-### Backend Testing
-
-- Use `pytest` for all backend tests
-- Write tests first (TDD approach)
-- Test structure: Arrange-Act-Assert
-
-## Code Style Requirements
-
-### Python
-
-- Use type hints for all functions and class attributes
-- No `Any` types unless absolutely necessary
-- Implement special methods (`__repr__`, `__str__`) appropriately
-
-### TypeScript/JavaScript
-
-- Strict TypeScript configuration
-- ESLint with Prettier integration
-- Avoid `any` type
-
-## Important Notes
-
-- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
-- **Comments**: Only write meaningful comments that explain "why", not "what"
-- **File Creation**: Always prefer editing existing files over creating new ones
-- **Documentation**: Don't create documentation files unless explicitly requested
-- **Code Quality**: Always run `./dev/reformat` before committing backend changes
-
-## Common Development Tasks
-
-### Adding a New API Endpoint
-
-1. Create controller in `/api/controllers/`
-1. Add service logic in `/api/services/`
-1. Update routes in controller's `__init__.py`
-1. Write tests in `/api/tests/`
-
-## Project-Specific Conventions
-
-- All async tasks use Celery with Redis as broker
-- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
@@ -328,7 +328,7 @@ MATRIXONE_DATABASE=dify
 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
 LINDORM_USERNAME=admin
 LINDORM_PASSWORD=admin
-USING_UGC_INDEX=False
+LINDORM_USING_UGC=True
 LINDORM_QUERY_TIMEOUT=1
 
 # OceanBase Vector configuration
@@ -5,7 +5,7 @@ line-length = 120
 quote-style = "double"
 
 [lint]
-preview = false
+preview = true
 select = [
     "B", # flake8-bugbear rules
     "C4", # flake8-comprehensions
@@ -45,6 +45,7 @@ select = [
     "G001", # don't use str format to logging messages
     "G003", # don't use + in logging messages
     "G004", # don't use f-strings to format logging messages
+    "UP042", # use StrEnum
 ]
 
 ignore = [
@@ -64,6 +65,7 @@ ignore = [
     "B006", # mutable-argument-default
     "B007", # unused-loop-control-variable
     "B026", # star-arg-unpacking-after-keyword-arg
+    "B901", # allow return in yield
     "B903", # class-as-data-structure
     "B904", # raise-without-from-inside-except
     "B905", # zip-without-explicit-strict
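`UP042`, newly enabled above, flags classes that inherit from both `str` and `Enum`; a minimal sketch of the rewrite it suggests (requires Python 3.11+):

```python
from enum import Enum, StrEnum


class OldColor(str, Enum):  # UP042 flags this pattern
    RED = "red"


class Color(StrEnum):  # suggested replacement
    RED = "red"


# StrEnum members still compare equal to their underlying strings
assert Color.RED == "red" and OldColor.RED == "red"
```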
@@ -1,6 +1,7 @@
 import base64
 import json
 import logging
+import operator
 import secrets
 from typing import Any
 
@@ -953,7 +954,7 @@ def clear_orphaned_file_records(force: bool):
         click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
         query = "DELETE FROM message_files WHERE id IN :ids"
         with db.engine.begin() as conn:
-            conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
+            conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)})
         click.echo(
             click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
         )
@@ -1307,7 +1308,7 @@ def cleanup_orphaned_draft_variables(
 
     if dry_run:
        logger.info("DRY RUN: Would delete the following:")
-        for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
+        for app_id, count in sorted(stats["orphaned_by_app"].items(), key=operator.itemgetter(1), reverse=True)[
            :10
        ]:  # Show top 10
            logger.info("  App %s: %s variables", app_id, count)
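`operator.itemgetter(1)` (enabled by the new `import operator`) is a named, importable equivalent of the removed `lambda x: x[1]`; a quick check:

```python
import operator

pairs = [("app-a", 3), ("app-b", 7)]
# Both sort keys extract the count at index 1
assert sorted(pairs, key=operator.itemgetter(1), reverse=True) == sorted(
    pairs, key=lambda x: x[1], reverse=True
)
```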
@@ -19,15 +19,15 @@ class LindormConfig(BaseSettings):
         description="Lindorm password",
         default=None,
     )
-    DEFAULT_INDEX_TYPE: str | None = Field(
+    LINDORM_INDEX_TYPE: str | None = Field(
         description="Lindorm Vector Index Type, hnsw or flat is available in dify",
         default="hnsw",
     )
-    DEFAULT_DISTANCE_TYPE: str | None = Field(
+    LINDORM_DISTANCE_TYPE: str | None = Field(
         description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
     )
-    USING_UGC_INDEX: bool | None = Field(
-        description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
-        default=False,
+    LINDORM_USING_UGC: bool | None = Field(
+        description="Using UGC index will store indexes with the same IndexType/Dimension in a single big index.",
+        default=True,
     )
     LINDORM_QUERY_TIMEOUT: float | None = Field(description="The lindorm search request timeout (s)", default=2.0)
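A side effect worth noting: with pydantic-settings, renaming a `BaseSettings` field also renames the environment variable that feeds it, which is why the docker hunk above switches from `USING_UGC_INDEX` to `LINDORM_USING_UGC`. A minimal sketch, assuming pydantic-settings' default env binding (the `LindormDemo` class is illustrative only):

```python
import os

from pydantic_settings import BaseSettings


class LindormDemo(BaseSettings):
    # By default the field name doubles as the environment variable name
    LINDORM_USING_UGC: bool | None = True


os.environ["LINDORM_USING_UGC"] = "false"
assert LindormDemo().LINDORM_USING_UGC is False
```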
@@ -355,8 +355,8 @@ class AliyunDataTrace(BaseTraceInstance):
                 GEN_AI_FRAMEWORK: "dify",
                 TOOL_NAME: node_execution.title,
                 TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
-                TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
-                INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
+                TOOL_PARAMETERS: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
+                INPUT_VALUE: json.dumps(node_execution.inputs or {}, ensure_ascii=False),
                 OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
             },
             status=self.get_workflow_node_status(node_execution),
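The `x if x else {}` → `x or {}` rewrites here, and in the LangFuse/LangSmith/Opik/Weave hunks below, are behavior-preserving: both forms branch on the same truthiness test, so every falsy input (`None`, `{}`, `""`) maps to `{}`. A quick check:

```python
for inputs in (None, {}, {"k": "v"}):
    assert (inputs if inputs else {}) == (inputs or {})
```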
@@ -144,13 +144,13 @@ class LangFuseDataTrace(BaseTraceInstance):
         if node_type == NodeType.LLM:
             inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
         else:
-            inputs = node_execution.inputs if node_execution.inputs else {}
-        outputs = node_execution.outputs if node_execution.outputs else {}
+            inputs = node_execution.inputs or {}
+        outputs = node_execution.outputs or {}
         created_at = node_execution.created_at or datetime.now()
         elapsed_time = node_execution.elapsed_time
         finished_at = created_at + timedelta(seconds=elapsed_time)
 
-        execution_metadata = node_execution.metadata if node_execution.metadata else {}
+        execution_metadata = node_execution.metadata or {}
         metadata = {str(k): v for k, v in execution_metadata.items()}
         metadata.update(
             {
@@ -163,7 +163,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 "status": status,
             }
         )
-        process_data = node_execution.process_data if node_execution.process_data else {}
+        process_data = node_execution.process_data or {}
         model_provider = process_data.get("model_provider", None)
         model_name = process_data.get("model_name", None)
         if model_provider is not None and model_name is not None:
@@ -167,13 +167,13 @@ class LangSmithDataTrace(BaseTraceInstance):
         if node_type == NodeType.LLM:
             inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
         else:
-            inputs = node_execution.inputs if node_execution.inputs else {}
-        outputs = node_execution.outputs if node_execution.outputs else {}
+            inputs = node_execution.inputs or {}
+        outputs = node_execution.outputs or {}
         created_at = node_execution.created_at or datetime.now()
         elapsed_time = node_execution.elapsed_time
         finished_at = created_at + timedelta(seconds=elapsed_time)
 
-        execution_metadata = node_execution.metadata if node_execution.metadata else {}
+        execution_metadata = node_execution.metadata or {}
         node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
         metadata = {str(key): value for key, value in execution_metadata.items()}
         metadata.update(
@@ -188,7 +188,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             }
         )
 
-        process_data = node_execution.process_data if node_execution.process_data else {}
+        process_data = node_execution.process_data or {}
 
         if process_data and process_data.get("model_mode") == "chat":
             run_type = LangSmithRunType.llm
@@ -182,13 +182,13 @@ class OpikDataTrace(BaseTraceInstance):
         if node_type == NodeType.LLM:
             inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
         else:
-            inputs = node_execution.inputs if node_execution.inputs else {}
-        outputs = node_execution.outputs if node_execution.outputs else {}
+            inputs = node_execution.inputs or {}
+        outputs = node_execution.outputs or {}
         created_at = node_execution.created_at or datetime.now()
         elapsed_time = node_execution.elapsed_time
         finished_at = created_at + timedelta(seconds=elapsed_time)
 
-        execution_metadata = node_execution.metadata if node_execution.metadata else {}
+        execution_metadata = node_execution.metadata or {}
         metadata = {str(k): v for k, v in execution_metadata.items()}
         metadata.update(
             {
@@ -202,7 +202,7 @@ class OpikDataTrace(BaseTraceInstance):
             }
         )
 
-        process_data = node_execution.process_data if node_execution.process_data else {}
+        process_data = node_execution.process_data or {}
 
         provider = None
         model = None
@@ -1,3 +1,4 @@
+import collections
 import json
 import logging
 import os
@@ -40,7 +41,7 @@ from tasks.ops_trace_task import process_trace_tasks
 logger = logging.getLogger(__name__)
 
 
-class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
+class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
     def __getitem__(self, provider: str) -> dict[str, Any]:
         match provider:
             case TracingProviderEnum.LANGFUSE:
@@ -121,7 +122,7 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
                 raise KeyError(f"Unsupported tracing provider: {provider}")
 
 
-provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
+provider_config_map = OpsTraceProviderConfigMap()
 
 
 class OpsTraceManager:
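The switch from subclassing `dict` to `collections.UserDict` matters because built-in `dict` methods such as `.get()` bypass an overridden `__getitem__`, while `UserDict` routes through it (this also explains dropping the `dict[...]` annotation on `provider_config_map`, since `UserDict` is not a `dict` subclass). An illustration with hypothetical classes, not from this file:

```python
import collections


class UpperDict(dict):
    def __getitem__(self, key):
        return super().__getitem__(key.lower())


class UpperUserDict(collections.UserDict):
    def __getitem__(self, key):
        return self.data[key.lower()]


d, u = UpperDict(a=1), UpperUserDict(a=1)
assert d["A"] == u["A"] == 1  # [] honors the override in both
assert d.get("A") is None     # dict.get() skips __getitem__
assert u.get("A") == 1        # UserDict.get() goes through it
```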
@@ -169,13 +169,13 @@ class WeaveDataTrace(BaseTraceInstance):
         if node_type == NodeType.LLM:
             inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
         else:
-            inputs = node_execution.inputs if node_execution.inputs else {}
-        outputs = node_execution.outputs if node_execution.outputs else {}
+            inputs = node_execution.inputs or {}
+        outputs = node_execution.outputs or {}
         created_at = node_execution.created_at or datetime.now()
         elapsed_time = node_execution.elapsed_time
         finished_at = created_at + timedelta(seconds=elapsed_time)
 
-        execution_metadata = node_execution.metadata if node_execution.metadata else {}
+        execution_metadata = node_execution.metadata or {}
         node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
         attributes = {str(k): v for k, v in execution_metadata.items()}
         attributes.update(
@@ -190,7 +190,7 @@ class WeaveDataTrace(BaseTraceInstance):
             }
         )
 
-        process_data = node_execution.process_data if node_execution.process_data else {}
+        process_data = node_execution.process_data or {}
         if process_data and process_data.get("model_mode") == "chat":
             attributes.update(
                 {
@@ -641,7 +641,7 @@ class ClickzettaVector(BaseVector):
 
         for doc, embedding in zip(batch_docs, batch_embeddings):
             # Optimized: minimal checks for common case, fallback for edge cases
-            metadata = doc.metadata if doc.metadata else {}
+            metadata = doc.metadata or {}
 
             if not isinstance(metadata, dict):
                 metadata = {}
@@ -1,4 +1,3 @@
-import copy
 import json
 import logging
 import time
@@ -28,7 +27,7 @@ UGC_INDEX_PREFIX = "ugc_index"
 
 
 class LindormVectorStoreConfig(BaseModel):
-    hosts: str
+    hosts: str | None
     username: str | None = None
     password: str | None = None
     using_ugc: bool | None = False
@@ -46,7 +45,12 @@ class LindormVectorStoreConfig(BaseModel):
         return values
 
     def to_opensearch_params(self) -> dict[str, Any]:
-        params: dict[str, Any] = {"hosts": self.hosts}
+        params: dict[str, Any] = {
+            "hosts": self.hosts,
+            "use_ssl": False,
+            "pool_maxsize": 128,
+            "timeout": 30,
+        }
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params
@@ -54,18 +58,13 @@
 
 class LindormVectorStore(BaseVector):
     def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using_ugc: bool, **kwargs):
-        self._routing = None
-        self._routing_field = None
+        self._routing: str | None = None
         if using_ugc:
             routing_value: str | None = kwargs.get("routing_value")
             if routing_value is None:
                 raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
             self._routing = routing_value.lower()
-            self._routing_field = ROUTING_FIELD
-            ugc_index_name = collection_name
-            super().__init__(ugc_index_name.lower())
-        else:
-            super().__init__(collection_name.lower())
+        super().__init__(collection_name.lower())
         self._client_config = config
         self._client = OpenSearch(**config.to_opensearch_params())
         self._using_ugc = using_ugc
@@ -75,7 +74,8 @@ class LindormVectorStore(BaseVector):
         return VectorType.LINDORM
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        self.create_collection(len(embeddings[0]), **kwargs)
+        metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
+        self.create_collection(embeddings, metadatas)
         self.add_texts(texts, embeddings)
 
     def refresh(self):
@@ -120,7 +120,7 @@ class LindormVectorStore(BaseVector):
             for i in range(start_idx, end_idx):
                 action_header = {
                     "index": {
-                        "_index": self.collection_name.lower(),
+                        "_index": self.collection_name,
                         "_id": uuids[i],
                     }
                 }
@@ -131,14 +131,11 @@ class LindormVectorStore(BaseVector):
                 }
                 if self._using_ugc:
                     action_header["index"]["routing"] = self._routing
-                    if self._routing_field is not None:
-                        action_values[self._routing_field] = self._routing
+                    action_values[ROUTING_FIELD] = self._routing
 
                 actions.append(action_header)
                 actions.append(action_values)
 
-            # logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
-
             try:
                 _bulk_with_retry(actions)
-                # logger.info(f"Successfully processed batch {batch_num + 1}")
@@ -155,7 +152,7 @@ class LindormVectorStore(BaseVector):
             "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
         }
         if self._using_ugc:
-            query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
+            query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
         response = self._client.search(index=self._collection_name, body=query)
         if response["hits"]["hits"]:
             return [hit["_id"] for hit in response["hits"]["hits"]]
@@ -216,7 +213,7 @@ class LindormVectorStore(BaseVector):
     def delete(self):
         if self._using_ugc:
             routing_filter_query = {
-                "query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
+                "query": {"bool": {"must": [{"term": {f"{ROUTING_FIELD}.keyword": self._routing}}]}}
             }
             self._client.delete_by_query(self._collection_name, body=routing_filter_query)
             self.refresh()
@@ -229,7 +226,7 @@ class LindormVectorStore(BaseVector):
     def text_exists(self, id: str) -> bool:
         try:
-            params = {}
+            params: dict[str, Any] = {}
             if self._using_ugc:
                 params["routing"] = self._routing
             self._client.get(index=self._collection_name, id=id, params=params)
@@ -244,20 +241,37 @@
         if not all(isinstance(x, float) for x in query_vector):
             raise ValueError("All elements in query_vector should be floats")
 
-        top_k = kwargs.get("top_k", 3)
-        document_ids_filter = kwargs.get("document_ids_filter")
         filters = []
+        document_ids_filter = kwargs.get("document_ids_filter")
         if document_ids_filter:
             filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
-        query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
+        if self._using_ugc:
+            filters.append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
+
+        top_k = kwargs.get("top_k", 5)
+        search_query: dict[str, Any] = {
+            "size": top_k,
+            "_source": True,
+            "query": {"knn": {Field.VECTOR.value: {"vector": query_vector, "k": top_k}}},
+        }
+
+        final_ext: dict[str, Any] = {"lvector": {}}
+        if filters is not None and len(filters) > 0:
+            # when using filter, transform filter from List[Dict] to Dict as valid format
+            filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
+            search_query["query"]["knn"][Field.VECTOR.value]["filter"] = filter_dict  # filter should be Dict
+            final_ext["lvector"]["filter_type"] = "pre_filter"
+
+        if final_ext != {"lvector": {}}:
+            search_query["ext"] = final_ext
+
         try:
             params = {"timeout": self._client_config.request_timeout}
             if self._using_ugc:
                 params["routing"] = self._routing  # type: ignore
-            response = self._client.search(index=self._collection_name, body=query, params=params)
+            response = self._client.search(index=self._collection_name, body=search_query, params=params)
         except Exception:
-            logger.exception("Error executing vector search, query: %s", query)
+            logger.exception("Error executing vector search, query: %s", search_query)
             raise
 
         docs_and_scores = []
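For illustration, the request body the inlined code above produces for a filtered kNN search — values hypothetical, shape taken from the hunk (`Field.VECTOR.value` assumed to be `"vector"`):

```python
search_query = {
    "size": 3,  # top_k
    "_source": True,
    "query": {
        "knn": {
            "vector": {  # Field.VECTOR.value
                "vector": [0.1, 0.2, 0.3],
                "k": 3,
                # single filter -> used as-is; multiple -> wrapped in bool/must
                "filter": {"terms": {"metadata.document_id.keyword": ["doc-1"]}},
            }
        }
    },
    # Lindorm extension: apply the filter before the vector search
    "ext": {"lvector": {"filter_type": "pre_filter"}},
}
```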
@@ -283,283 +297,85 @@ class LindormVectorStore(BaseVector):
         return docs
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        must = kwargs.get("must")
-        must_not = kwargs.get("must_not")
-        should = kwargs.get("should")
-        minimum_should_match = kwargs.get("minimum_should_match", 0)
         top_k = kwargs.get("top_k", 3)
-        filters = kwargs.get("filter", [])
+        full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}}
+        filters = []
         document_ids_filter = kwargs.get("document_ids_filter")
         if document_ids_filter:
             filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
-        routing = self._routing
-        full_text_query = default_text_search_query(
-            query_text=query,
-            k=top_k,
-            text_field=Field.CONTENT_KEY.value,
-            must=must,
-            must_not=must_not,
-            should=should,
-            minimum_should_match=minimum_should_match,
-            filters=filters,
-            routing=routing,
-            routing_field=self._routing_field,
-        )
-        params = {"timeout": self._client_config.request_timeout}
-        response = self._client.search(index=self._collection_name, body=full_text_query, params=params)
+        if self._using_ugc:
+            filters.append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
+        if filters:
+            full_text_query["query"]["bool"]["filter"] = filters
+
+        try:
+            params: dict[str, Any] = {"timeout": self._client_config.request_timeout}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            response = self._client.search(index=self._collection_name, body=full_text_query, params=params)
+        except Exception:
+            logger.exception("Error executing vector search, query: %s", full_text_query)
+            raise
 
         docs = []
         for hit in response["hits"]["hits"]:
-            docs.append(
-                Document(
-                    page_content=hit["_source"][Field.CONTENT_KEY.value],
-                    vector=hit["_source"][Field.VECTOR.value],
-                    metadata=hit["_source"][Field.METADATA_KEY.value],
-                )
-            )
+            metadata = hit["_source"].get(Field.METADATA_KEY.value)
+            vector = hit["_source"].get(Field.VECTOR.value)
+            page_content = hit["_source"].get(Field.CONTENT_KEY.value)
+            doc = Document(page_content=page_content, vector=vector, metadata=metadata)
+            docs.append(doc)
 
         return docs
 
-    def create_collection(self, dimension: int, **kwargs):
+    def create_collection(
+        self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
+    ):
+        if not embeddings:
+            raise ValueError(f"Embeddings list cannot be empty for collection create '{self._collection_name}'")
         lock_name = f"vector_indexing_lock_{self._collection_name}"
         with redis_client.lock(lock_name, timeout=20):
             collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
             if redis_client.get(collection_exist_cache_key):
                 logger.info("Collection %s already exists.", self._collection_name)
                 return
-            if self._client.indices.exists(index=self._collection_name):
-                logger.info("%s already exists.", self._collection_name.lower())
-                redis_client.set(collection_exist_cache_key, 1, ex=3600)
-                return
-            if len(self.kwargs) == 0 and len(kwargs) != 0:
-                self.kwargs = copy.deepcopy(kwargs)
-            vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
-            shards = kwargs.pop("shards", 4)
-
-            engine = kwargs.pop("engine", "lvector")
-            method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
-            space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
-            data_type = kwargs.pop("data_type", "float")
-
-            hnsw_m = kwargs.pop("hnsw_m", 24)
-            hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
-            ivfpq_m = kwargs.pop("ivfpq_m", dimension)
-            nlist = kwargs.pop("nlist", 1000)
-            centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", nlist >= 5000)
-            centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
-            centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
-            centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
-            mapping = default_text_mapping(
-                dimension,
-                method_name,
-                space_type=space_type,
-                shards=shards,
-                engine=engine,
-                data_type=data_type,
-                vector_field=vector_field,
-                hnsw_m=hnsw_m,
-                hnsw_ef_construction=hnsw_ef_construction,
-                nlist=nlist,
-                ivfpq_m=ivfpq_m,
-                centroids_use_hnsw=centroids_use_hnsw,
-                centroids_hnsw_m=centroids_hnsw_m,
-                centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
-                centroids_hnsw_ef_search=centroids_hnsw_ef_search,
-                using_ugc=self._using_ugc,
-                **kwargs,
-            )
-            self._client.indices.create(index=self._collection_name.lower(), body=mapping)
-            redis_client.set(collection_exist_cache_key, 1, ex=3600)
-            # logger.info(f"create index success: {self._collection_name}")
+            if not self._client.indices.exists(index=self._collection_name):
+                index_body = {
+                    "settings": {"index": {"knn": True, "knn_routing": self._using_ugc}},
+                    "mappings": {
+                        "properties": {
+                            Field.CONTENT_KEY.value: {"type": "text"},
+                            Field.VECTOR.value: {
+                                "type": "knn_vector",
+                                "dimension": len(embeddings[0]),  # Make sure the dimension is correct here
+                                "method": {
+                                    "name": index_params.get("index_type", "hnsw")
+                                    if index_params
+                                    else dify_config.LINDORM_INDEX_TYPE,
+                                    "space_type": index_params.get("space_type", "l2")
+                                    if index_params
+                                    else dify_config.LINDORM_DISTANCE_TYPE,
+                                    "engine": "lvector",
+                                },
+                            },
+                        }
+                    },
+                }
+                logger.info("Creating Lindorm Search index %s", self._collection_name)
+                self._client.indices.create(index=self._collection_name, body=index_body)
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
 
-def default_text_mapping(dimension: int, method_name: str, **kwargs: Any):
-    excludes_from_source = kwargs.get("excludes_from_source", False)
-    analyzer = kwargs.get("analyzer", "ik_max_word")
-    text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
-    engine = kwargs["engine"]
-    shard = kwargs["shards"]
-    space_type = kwargs.get("space_type")
-    if space_type is None:
-        if method_name == "hnsw":
-            space_type = "l2"
-        else:
-            space_type = "cosine"
-    data_type = kwargs["data_type"]
-    vector_field = kwargs.get("vector_field", Field.VECTOR.value)
-    using_ugc = kwargs.get("using_ugc", False)
-
-    if method_name == "ivfpq":
-        ivfpq_m = kwargs["ivfpq_m"]
-        nlist = kwargs["nlist"]
-        centroids_use_hnsw = nlist > 10000
-        centroids_hnsw_m = 24
-        centroids_hnsw_ef_construct = 500
-        centroids_hnsw_ef_search = 100
-        parameters = {
-            "m": ivfpq_m,
-            "nlist": nlist,
-            "centroids_use_hnsw": centroids_use_hnsw,
-            "centroids_hnsw_m": centroids_hnsw_m,
-            "centroids_hnsw_ef_construct": centroids_hnsw_ef_construct,
-            "centroids_hnsw_ef_search": centroids_hnsw_ef_search,
-        }
-    elif method_name == "hnsw":
-        neighbor = kwargs["hnsw_m"]
-        ef_construction = kwargs["hnsw_ef_construction"]
-        parameters = {"m": neighbor, "ef_construction": ef_construction}
-    elif method_name == "flat":
-        parameters = {}
-    else:
-        raise RuntimeError(f"unexpected method_name: {method_name}")
-
-    mapping = {
-        "settings": {"index": {"number_of_shards": shard, "knn": True}},
-        "mappings": {
-            "properties": {
-                vector_field: {
-                    "type": "knn_vector",
-                    "dimension": dimension,
-                    "data_type": data_type,
-                    "method": {
-                        "engine": engine,
-                        "name": method_name,
-                        "space_type": space_type,
-                        "parameters": parameters,
-                    },
-                },
-                text_field: {"type": "text", "analyzer": analyzer},
-            }
-        },
-    }
-
-    if excludes_from_source:
-        # e.g. {"excludes": ["vector_field"]}
-        mapping["mappings"]["_source"] = {"excludes": [vector_field]}
-
-    if using_ugc and method_name == "ivfpq":
-        mapping["settings"]["index"]["knn_routing"] = True
-        mapping["settings"]["index"]["knn.offline.construction"] = True
-    elif (using_ugc and method_name == "hnsw") or (using_ugc and method_name == "flat"):
-        mapping["settings"]["index"]["knn_routing"] = True
-    return mapping
-
-
-def default_text_search_query(
-    query_text: str,
-    k: int = 4,
-    text_field: str = Field.CONTENT_KEY.value,
-    must: list[dict] | None = None,
-    must_not: list[dict] | None = None,
-    should: list[dict] | None = None,
-    minimum_should_match: int = 0,
-    filters: list[dict] | None = None,
-    routing: str | None = None,
-    routing_field: str | None = None,
-    **kwargs,
-):
-    query_clause: dict[str, Any] = {}
-    if routing is not None:
-        query_clause = {
-            "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
-        }
-    else:
-        query_clause = {"match": {text_field: query_text}}
-    # build the simplest search_query when only query_text is specified
-    if not must and not must_not and not should and not filters:
-        search_query = {"size": k, "query": query_clause}
-        return search_query
-
-    # build complex search_query when either of must/must_not/should/filter is specified
-    if must:
-        if not isinstance(must, list):
-            raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
-        if query_clause not in must:
-            must.append(query_clause)
-    else:
-        must = [query_clause]
-
-    boolean_query: dict[str, Any] = {"must": must}
-
-    if must_not:
-        if not isinstance(must_not, list):
-            raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
-        boolean_query["must_not"] = must_not
-
-    if should:
-        if not isinstance(should, list):
-            raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
-        boolean_query["should"] = should
-        if minimum_should_match != 0:
-            boolean_query["minimum_should_match"] = minimum_should_match
-
-    if filters:
-        if not isinstance(filters, list):
-            raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
-        boolean_query["filter"] = filters
-
-    search_query = {"size": k, "query": {"bool": boolean_query}}
-    return search_query
-
-
-def default_vector_search_query(
-    query_vector: list[float],
-    k: int = 4,
-    min_score: str = "0.0",
-    ef_search: str | None = None,  # only for hnsw
-    nprobe: str | None = None,  # "2000"
-    reorder_factor: str | None = None,  # "20"
-    client_refactor: str | None = None,  # "true"
-    vector_field: str = Field.VECTOR.value,
-    filters: list[dict] | None = None,
-    filter_type: str | None = None,
-    **kwargs,
-):
-    if filters is not None:
-        filter_type = "pre_filter" if filter_type is None else filter_type
-        if not isinstance(filters, list):
-            raise RuntimeError(f"unexpected filter with {type(filters)}")
-    final_ext: dict[str, Any] = {"lvector": {}}
-    if min_score != "0.0":
-        final_ext["lvector"]["min_score"] = min_score
-    if ef_search:
-        final_ext["lvector"]["ef_search"] = ef_search
-    if nprobe:
-        final_ext["lvector"]["nprobe"] = nprobe
-    if reorder_factor:
-        final_ext["lvector"]["reorder_factor"] = reorder_factor
-    if client_refactor:
-        final_ext["lvector"]["client_refactor"] = client_refactor
-
-    search_query: dict[str, Any] = {
-        "size": k,
-        "_source": True,  # force return '_source'
-        "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
-    }
-
-    if filters is not None and len(filters) > 0:
-        # when using filter, transform filter from List[Dict] to Dict as valid format
-        filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
-        search_query["query"]["knn"][vector_field]["filter"] = filter_dict  # filter should be Dict
-        if filter_type:
-            final_ext["lvector"]["filter_type"] = filter_type
-
-    if final_ext != {"lvector": {}}:
-        search_query["ext"] = final_ext
-    return search_query
 
 
 class LindormVectorStoreFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
         lindorm_config = LindormVectorStoreConfig(
-            hosts=dify_config.LINDORM_URL or "",
+            hosts=dify_config.LINDORM_URL,
             username=dify_config.LINDORM_USERNAME,
             password=dify_config.LINDORM_PASSWORD,
-            using_ugc=dify_config.USING_UGC_INDEX,
+            using_ugc=dify_config.LINDORM_USING_UGC,
             request_timeout=dify_config.LINDORM_QUERY_TIMEOUT,
         )
-        using_ugc = dify_config.USING_UGC_INDEX
+        using_ugc = dify_config.LINDORM_USING_UGC
         if using_ugc is None:
-            raise ValueError("USING_UGC_INDEX is not set")
+            raise ValueError("LINDORM_USING_UGC is not set")
         routing_value = None
         if dataset.index_struct:
            # if an existed record's index_struct_dict doesn't contain using_ugc field,
@@ -571,27 +387,27 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
                 index_type = dataset.index_struct_dict["index_type"]
                 distance_type = dataset.index_struct_dict["distance_type"]
                 routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
-                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}".lower()
             else:
-                index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
+                index_name = dataset.index_struct_dict["vector_store"]["class_prefix"].lower()
         else:
             embedding_vector = embeddings.embed_query("hello word")
             dimension = len(embedding_vector)
-            index_type = dify_config.DEFAULT_INDEX_TYPE
-            distance_type = dify_config.DEFAULT_DISTANCE_TYPE
             class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
             index_struct_dict = {
                 "type": VectorType.LINDORM,
                 "vector_store": {"class_prefix": class_prefix},
-                "index_type": index_type,
+                "index_type": dify_config.LINDORM_INDEX_TYPE,
                 "dimension": dimension,
-                "distance_type": distance_type,
+                "distance_type": dify_config.LINDORM_DISTANCE_TYPE,
                 "using_ugc": using_ugc,
             }
             dataset.index_struct = json.dumps(index_struct_dict)
             if using_ugc:
-                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
-                routing_value = class_prefix
+                index_type = dify_config.LINDORM_INDEX_TYPE
+                distance_type = dify_config.LINDORM_DISTANCE_TYPE
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}".lower()
+                routing_value = class_prefix.lower()
             else:
-                index_name = class_prefix
+                index_name = class_prefix.lower()
         return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value, using_ugc=using_ugc)
@@ -103,7 +103,7 @@ class MatrixoneVector(BaseVector):
         self.client = self._get_client(len(embeddings[0]), True)
         assert self.client is not None
         ids = []
-        for _, doc in enumerate(documents):
+        for doc in documents:
             if doc.metadata is not None:
                 doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
                 ids.append(doc_id)
@@ -104,7 +104,7 @@ class OpenSearchVector(BaseVector):
             },
         }
         # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
-        if self._client_config.aws_service not in ["aoss"]:
+        if self._client_config.aws_service != "aoss":
             action["_id"] = uuid4().hex
         actions.append(action)
@@ -159,7 +159,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
             else None
         )
         db_model.status = domain_model.status
-        db_model.error = domain_model.error_message if domain_model.error_message else None
+        db_model.error = domain_model.error_message or None
         db_model.total_tokens = domain_model.total_tokens
         db_model.total_steps = domain_model.total_steps
         db_model.exceptions_count = domain_model.exceptions_count
@@ -320,7 +320,7 @@ class AgentNode(BaseNode):
         memory = self._fetch_memory(model_instance)
         if memory:
             prompt_messages = memory.get_history_prompt_messages(
-                message_limit=node_data.memory.window.size if node_data.memory.window.size else None
+                message_limit=node_data.memory.window.size or None
             )
             history_prompt_messages = [
                 prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
@@ -141,9 +141,7 @@ def init_app(app: DifyApp) -> Celery:
         imports.append("schedule.queue_monitor_task")
         beat_schedule["datasets-queue-monitor"] = {
             "task": "schedule.queue_monitor_task.queue_monitor_task",
-            "schedule": timedelta(
-                minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
-            ),
+            "schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
         }
     if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
         imports.append("schedule.check_upgradable_plugin_task")
@@ -7,6 +7,7 @@ Supports complete lifecycle management for knowledge base files.
 
 import json
 import logging
+import operator
 from dataclasses import asdict, dataclass
 from datetime import datetime
 from enum import StrEnum, auto
@@ -356,7 +357,7 @@ class FileLifecycleManager:
         # Cleanup old versions for each file
         for base_filename, versions in file_versions.items():
             # Sort by version number
-            versions.sort(key=lambda x: x[0], reverse=True)
+            versions.sort(key=operator.itemgetter(0), reverse=True)
 
             # Keep the newest max_versions versions, delete the rest
             if len(versions) > max_versions:
@@ -375,13 +375,14 @@ class WorkflowService:
 
     def _validate_llm_model_config(self, tenant_id: str, provider: str, model_name: str) -> None:
         """
-        Validate that an LLM model configuration can fetch valid credentials.
+        Validate that an LLM model configuration can fetch valid credentials and has active status.
 
         This method attempts to get the model instance and validates that:
         1. The provider exists and is configured
         2. The model exists in the provider
         3. Credentials can be fetched for the model
         4. The credentials pass policy compliance checks
+        5. The model status is ACTIVE (not NO_CONFIGURE, DISABLED, etc.)
 
         :param tenant_id: The tenant ID
         :param provider: The provider name
@@ -391,6 +392,7 @@ class WorkflowService:
         try:
             from core.model_manager import ModelManager
             from core.model_runtime.entities.model_entities import ModelType
+            from core.provider_manager import ProviderManager
 
             # Get model instance to validate provider+model combination
             model_manager = ModelManager()
@@ -402,6 +404,22 @@ class WorkflowService:
             # via ProviderConfiguration.get_current_credentials() -> _check_credential_policy_compliance()
             # If it fails, an exception will be raised
 
+            # Additionally, check the model status to ensure it's ACTIVE
+            provider_manager = ProviderManager()
+            provider_configurations = provider_manager.get_configurations(tenant_id)
+            models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM)
+
+            target_model = None
+            for model in models:
+                if model.model == model_name and model.provider.provider == provider:
+                    target_model = model
+                    break
+
+            if target_model:
+                target_model.raise_for_status()
+            else:
+                raise ValueError(f"Model {model_name} not found for provider {provider}")
+
         except Exception as e:
             raise ValueError(
                 f"Failed to validate LLM model configuration (provider: {provider}, model: {model_name}): {str(e)}"
@@ -1,3 +1,4 @@
+import operator
 import traceback
 import typing
 
@@ -118,7 +119,7 @@ def process_tenant_plugin_autoupgrade_check_task(
                 current_version = version
                 latest_version = manifest.latest_version
 
-                def fix_only_checker(latest_version, current_version):
+                def fix_only_checker(latest_version: str, current_version: str):
                     latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
                     current_version_tuple = tuple(int(val) for val in current_version.split("."))
 
@@ -130,8 +131,7 @@ def process_tenant_plugin_autoupgrade_check_task(
                     return False
 
                 version_checker = {
-                    TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
-                    current_version: latest_version != current_version,
+                    TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne,
                     TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
                 }
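`operator.ne` is the named equivalent of the removed two-argument lambda for the `LATEST` strategy: it returns `True` exactly when the two versions differ.

```python
import operator

assert operator.ne("1.2.4", "1.2.3")      # versions differ -> upgrade
assert not operator.ne("1.2.3", "1.2.3")  # same version -> no upgrade
```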
@@ -3,6 +3,7 @@
 import os
 import tempfile
 import unittest
+from pathlib import Path
 
 import pytest
 
@@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
         # Test download
         with tempfile.NamedTemporaryFile() as temp_file:
             storage.download(test_filename, temp_file.name)
-            with open(temp_file.name, "rb") as f:
-                downloaded_content = f.read()
+            downloaded_content = Path(temp_file.name).read_bytes()
             assert downloaded_content == test_content
 
         # Test scan
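`Path.read_bytes` is a one-line equivalent of the removed `open(..., "rb")` block; a self-contained check of the same round trip, using a hypothetical file name:

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "sample.bin"
    path.write_bytes(b"test content")
    assert path.read_bytes() == b"test content"
```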
@@ -12,6 +12,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances.
 
 import uuid
 from datetime import datetime
+from pathlib import Path
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -276,8 +277,7 @@ class TestBatchCreateSegmentToIndexTask:
         mock_storage = mock_external_service_dependencies["storage"]
 
         def mock_download(key, file_path):
-            with open(file_path, "w", encoding="utf-8") as f:
-                f.write(csv_content)
+            Path(file_path).write_text(csv_content, encoding="utf-8")
 
         mock_storage.download.side_effect = mock_download
 
@@ -505,7 +505,7 @@ class TestBatchCreateSegmentToIndexTask:
         db.session.commit()
 
         # Test each unavailable document
-        for i, document in enumerate(test_cases):
+        for document in test_cases:
             job_id = str(uuid.uuid4())
             batch_create_segment_to_index_task(
                 job_id=job_id,
@@ -601,8 +601,7 @@ class TestBatchCreateSegmentToIndexTask:
         mock_storage = mock_external_service_dependencies["storage"]
 
         def mock_download(key, file_path):
-            with open(file_path, "w", encoding="utf-8") as f:
-                f.write(empty_csv_content)
+            Path(file_path).write_text(empty_csv_content, encoding="utf-8")
 
         mock_storage.download.side_effect = mock_download
 
@@ -684,8 +683,7 @@ class TestBatchCreateSegmentToIndexTask:
         mock_storage = mock_external_service_dependencies["storage"]
 
        def mock_download(key, file_path):
-            with open(file_path, "w", encoding="utf-8") as f:
-                f.write(csv_content)
+            Path(file_path).write_text(csv_content, encoding="utf-8")
 
         mock_storage.download.side_effect = mock_download
@@ -362,7 +362,7 @@ class TestCleanDatasetTask:
 
         # Create segments for each document
         segments = []
-        for i, document in enumerate(documents):
+        for document in documents:
             segment = self._create_test_segment(db_session_with_containers, account, tenant, dataset, document)
             segments.append(segment)
@@ -290,9 +290,9 @@ class TestDisableSegmentsFromIndexTask:
         # Verify the call arguments (checking by attributes rather than object identity)
         call_args = mock_processor.clean.call_args
         assert call_args[0][0].id == dataset.id  # First argument should be the dataset
-        assert call_args[0][1] == [
-            segment.index_node_id for segment in segments
-        ]  # Second argument should be node IDs
+        assert sorted(call_args[0][1]) == sorted(
+            [segment.index_node_id for segment in segments]
+        )  # Compare sorted lists to handle any order while preserving duplicates
         assert call_args[1]["with_keywords"] is True
         assert call_args[1]["delete_child_chunks"] is False
 
@@ -719,7 +719,9 @@ class TestDisableSegmentsFromIndexTask:
         # Verify the call arguments
         call_args = mock_processor.clean.call_args
         assert call_args[0][0].id == dataset.id  # First argument should be the dataset
-        assert call_args[0][1] == expected_node_ids  # Second argument should be node IDs
+        assert sorted(call_args[0][1]) == sorted(
+            expected_node_ids
+        )  # Compare sorted lists to handle any order while preserving duplicates
         assert call_args[1]["with_keywords"] is True
         assert call_args[1]["delete_child_chunks"] is False
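Comparing sorted copies, as the updated assertions do, makes the check independent of call order while still failing on missing or duplicated node IDs:

```python
expected_node_ids = ["node-2", "node-1", "node-1"]
call_node_ids = ["node-1", "node-1", "node-2"]

assert call_node_ids != expected_node_ids                  # order-sensitive form fails
assert sorted(call_node_ids) == sorted(expected_node_ids)  # multiset comparison passes
```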
@@ -0,0 +1,554 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document
|
||||
from tasks.document_indexing_task import document_indexing_task
|
||||
|
||||
|
||||
class TestDocumentIndexingTask:
|
||||
"""Integration tests for document_indexing_task using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner,
|
||||
patch("tasks.document_indexing_task.FeatureService") as mock_feature_service,
|
||||
):
|
||||
# Setup mock indexing runner
|
||||
mock_runner_instance = MagicMock()
|
||||
mock_indexing_runner.return_value = mock_runner_instance
|
||||
|
||||
# Setup mock feature service
|
||||
mock_features = MagicMock()
|
||||
mock_features.billing.enabled = False
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
|
||||
yield {
|
||||
"indexing_runner": mock_indexing_runner,
|
||||
"indexing_runner_instance": mock_runner_instance,
|
||||
"feature_service": mock_feature_service,
|
||||
"features": mock_features,
|
||||
}
|
||||
|
||||
def _create_test_dataset_and_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies, document_count=3
|
||||
):
|
||||
"""
|
||||
Helper method to create a test dataset and documents for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
document_count: Number of documents to create
|
||||
|
||||
Returns:
|
||||
tuple: (dataset, documents) - Created dataset and document instances
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
        db.session.add(tenant)
        db.session.commit()

        # Create tenant-account join
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER.value,
            current=True,
        )
        db.session.add(join)
        db.session.commit()

        # Create dataset
        dataset = Dataset(
            id=fake.uuid4(),
            tenant_id=tenant.id,
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            data_source_type="upload_file",
            indexing_technique="high_quality",
            created_by=account.id,
        )
        db.session.add(dataset)
        db.session.commit()

        # Create documents
        documents = []
        for i in range(document_count):
            document = Document(
                id=fake.uuid4(),
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                position=i,
                data_source_type="upload_file",
                batch="test_batch",
                name=fake.file_name(),
                created_from="upload_file",
                created_by=account.id,
                indexing_status="waiting",
                enabled=True,
            )
            db.session.add(document)
            documents.append(document)

        db.session.commit()

        # Refresh dataset to ensure it's properly loaded
        db.session.refresh(dataset)

        return dataset, documents

    def _create_test_dataset_with_billing_features(
        self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
    ):
        """
        Helper method to create a test dataset with billing features configured.

        Args:
            db_session_with_containers: Database session from testcontainers infrastructure
            mock_external_service_dependencies: Mock dependencies
            billing_enabled: Whether billing is enabled

        Returns:
            tuple: (dataset, documents) - Created dataset and document instances
        """
        fake = Faker()

        # Create account and tenant
        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        db.session.add(account)
        db.session.commit()

        tenant = Tenant(
            name=fake.company(),
            status="normal",
        )
        db.session.add(tenant)
        db.session.commit()

        # Create tenant-account join
        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER.value,
            current=True,
        )
        db.session.add(join)
        db.session.commit()

        # Create dataset
        dataset = Dataset(
            id=fake.uuid4(),
            tenant_id=tenant.id,
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            data_source_type="upload_file",
            indexing_technique="high_quality",
            created_by=account.id,
        )
        db.session.add(dataset)
        db.session.commit()

        # Create documents
        documents = []
        for i in range(3):
            document = Document(
                id=fake.uuid4(),
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                position=i,
                data_source_type="upload_file",
                batch="test_batch",
                name=fake.file_name(),
                created_from="upload_file",
                created_by=account.id,
                indexing_status="waiting",
                enabled=True,
            )
            db.session.add(document)
            documents.append(document)

        db.session.commit()

        # Configure billing features
        mock_external_service_dependencies["features"].billing.enabled = billing_enabled
        if billing_enabled:
            mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"
            mock_external_service_dependencies["features"].vector_space.limit = 100
            mock_external_service_dependencies["features"].vector_space.size = 50
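            # Assumption: limit=100 with size=50 leaves vector-space headroom,
            # so only the sandbox batch-upload rule (not the capacity check)
            # should be able to reject documents in the billing tests below.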

        # Refresh dataset to ensure it's properly loaded
        db.session.refresh(dataset)

        return dataset, documents

    def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful document indexing with multiple documents.

        This test verifies:
        - Proper dataset retrieval from database
        - Correct document processing and status updates
        - IndexingRunner integration
        - Database state updates
        """
        # Arrange: Create test data
        dataset, documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=3
        )
        document_ids = [doc.id for doc in documents]

        # Act: Execute the task
        document_indexing_task(dataset.id, document_ids)
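        # Calling the Celery task as a plain function runs it synchronously in
        # the test process, so no broker or worker is needed here (assumption:
        # document_indexing_task is a standard Celery shared task).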

        # Assert: Verify the expected outcomes
        # Verify indexing runner was called correctly
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify documents were updated to parsing status
        for document in documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None

        # Verify the run method was called with correct documents
        call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
        assert call_args is not None
        processed_documents = call_args[0][0]  # First argument should be documents list
        assert len(processed_documents) == 3

    def test_document_indexing_task_dataset_not_found(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test handling of non-existent dataset.

        This test verifies:
        - Proper error handling for missing datasets
        - Early return without processing
        - Database session cleanup
        - No unnecessary indexing runner calls
        """
        # Arrange: Use non-existent dataset ID
        fake = Faker()
        non_existent_dataset_id = fake.uuid4()
        document_ids = [fake.uuid4() for _ in range(3)]

        # Act: Execute the task with non-existent dataset
        document_indexing_task(non_existent_dataset_id, document_ids)

        # Assert: Verify no processing occurred
        mock_external_service_dependencies["indexing_runner"].assert_not_called()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()

    def test_document_indexing_task_document_not_found_in_dataset(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test handling when some documents don't exist in the dataset.

        This test verifies:
        - Only existing documents are processed
        - Non-existent documents are ignored
        - Indexing runner receives only valid documents
        - Database state updates correctly
        """
        # Arrange: Create test data
        dataset, documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=2
        )

        # Mix existing and non-existent document IDs
        fake = Faker()
        existing_document_ids = [doc.id for doc in documents]
        non_existent_document_ids = [fake.uuid4() for _ in range(2)]
        all_document_ids = existing_document_ids + non_existent_document_ids

        # Act: Execute the task with mixed document IDs
        document_indexing_task(dataset.id, all_document_ids)

        # Assert: Verify only existing documents were processed
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify only existing documents were updated
        for document in documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None

        # Verify the run method was called with only existing documents
        call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
        assert call_args is not None
        processed_documents = call_args[0][0]  # First argument should be documents list
        assert len(processed_documents) == 2  # Only existing documents

    def test_document_indexing_task_indexing_runner_exception(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test handling of IndexingRunner exceptions.

        This test verifies:
        - Exceptions from IndexingRunner are properly caught
        - Task completes without raising exceptions
        - Database session is properly closed
        - Error logging occurs
        """
        # Arrange: Create test data
        dataset, documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=2
        )
        document_ids = [doc.id for doc in documents]

        # Mock IndexingRunner to raise an exception
        mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception(
            "Indexing runner failed"
        )

        # Act: Execute the task
        document_indexing_task(dataset.id, document_ids)

        # Assert: Verify exception was handled gracefully
        # The task should complete without raising exceptions
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify documents were still updated to parsing status before the exception
        for document in documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None

    def test_document_indexing_task_mixed_document_states(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test processing documents with mixed initial states.

        This test verifies:
        - Documents with different initial states are handled correctly
        - Only valid documents are processed
        - Database state updates are consistent
        - IndexingRunner receives correct documents
        """
        # Arrange: Create test data
        dataset, base_documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=2
        )

        # Create additional documents with different states
        fake = Faker()
        extra_documents = []

        # Document with different indexing status
        doc1 = Document(
            id=fake.uuid4(),
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            position=2,
            data_source_type="upload_file",
            batch="test_batch",
            name=fake.file_name(),
            created_from="upload_file",
            created_by=dataset.created_by,
            indexing_status="completed",  # Already completed
            enabled=True,
        )
        db.session.add(doc1)
        extra_documents.append(doc1)

        # Document with disabled status
        doc2 = Document(
            id=fake.uuid4(),
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            position=3,
            data_source_type="upload_file",
            batch="test_batch",
            name=fake.file_name(),
            created_from="upload_file",
            created_by=dataset.created_by,
            indexing_status="waiting",
            enabled=False,  # Disabled
        )
        db.session.add(doc2)
        extra_documents.append(doc2)

        db.session.commit()

        all_documents = base_documents + extra_documents
        document_ids = [doc.id for doc in all_documents]

        # Act: Execute the task with mixed document states
        document_indexing_task(dataset.id, document_ids)

        # Assert: Verify processing
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify all documents were updated to parsing status
        for document in all_documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None

        # Verify the run method was called with all documents
        call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
        assert call_args is not None
        processed_documents = call_args[0][0]  # First argument should be documents list
        assert len(processed_documents) == 4

    def test_document_indexing_task_billing_sandbox_plan_batch_limit(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test billing validation for sandbox plan batch upload limit.

        This test verifies:
        - Sandbox plan batch upload limit enforcement
        - Error handling for batch upload limit exceeded
        - Document status updates to error state
        - Proper error message recording
        """
        # Arrange: Create test data with billing enabled
        dataset, documents = self._create_test_dataset_with_billing_features(
            db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
        )

        # Configure sandbox plan with batch limit
        mock_external_service_dependencies["features"].billing.subscription.plan = "sandbox"

        # Create more documents than sandbox plan allows (limit is 1)
        fake = Faker()
        extra_documents = []
        for i in range(2):  # Total will be 5 documents (3 existing + 2 new)
            document = Document(
                id=fake.uuid4(),
                tenant_id=dataset.tenant_id,
                dataset_id=dataset.id,
                position=i + 3,
                data_source_type="upload_file",
                batch="test_batch",
                name=fake.file_name(),
                created_from="upload_file",
                created_by=dataset.created_by,
                indexing_status="waiting",
                enabled=True,
            )
            db.session.add(document)
            extra_documents.append(document)

        db.session.commit()
        all_documents = documents + extra_documents
        document_ids = [doc.id for doc in all_documents]

        # Act: Execute the task with too many documents for sandbox plan
        document_indexing_task(dataset.id, document_ids)

        # Assert: Verify error handling
        for document in all_documents:
            db.session.refresh(document)
            assert document.indexing_status == "error"
            assert document.error is not None
            assert "batch upload" in document.error
            assert document.stopped_at is not None

        # Verify no indexing runner was called
        mock_external_service_dependencies["indexing_runner"].assert_not_called()

    def test_document_indexing_task_billing_disabled_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful processing when billing is disabled.

        This test verifies:
        - Processing continues normally when billing is disabled
        - No billing validation occurs
        - Documents are processed successfully
        - IndexingRunner is called correctly
        """
        # Arrange: Create test data with billing disabled
        dataset, documents = self._create_test_dataset_with_billing_features(
            db_session_with_containers, mock_external_service_dependencies, billing_enabled=False
        )

        document_ids = [doc.id for doc in documents]

        # Act: Execute the task with billing disabled
        document_indexing_task(dataset.id, document_ids)

        # Assert: Verify successful processing
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify documents were updated to parsing status
        for document in documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None

    def test_document_indexing_task_document_is_paused_error(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test handling of DocumentIsPausedError from IndexingRunner.

        This test verifies:
        - DocumentIsPausedError is properly caught and handled
        - Task completes without raising exceptions
        - Appropriate logging occurs
        - Database session is properly closed
        """
        # Arrange: Create test data
        dataset, documents = self._create_test_dataset_and_documents(
            db_session_with_containers, mock_external_service_dependencies, document_count=2
        )
        document_ids = [doc.id for doc in documents]

        # Mock IndexingRunner to raise DocumentIsPausedError
        from core.indexing_runner import DocumentIsPausedError

        mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = DocumentIsPausedError(
            "Document indexing is paused"
        )

        # Act: Execute the task
        document_indexing_task(dataset.id, document_ids)

        # Assert: Verify exception was handled gracefully
        # The task should complete without raising exceptions
        mock_external_service_dependencies["indexing_runner"].assert_called_once()
        mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()

        # Verify documents were still updated to parsing status before the exception
        for document in documents:
            db.session.refresh(document)
            assert document.indexing_status == "parsing"
            assert document.processing_started_at is not None
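
A rough sketch of the task flow these tests pin down, inferred from the
assertions above rather than copied from the production code, so treat every
name and detail here as an assumption:

    import logging
    from datetime import UTC, datetime

    def document_indexing_task_sketch(dataset_id: str, document_ids: list[str]) -> None:
        dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
        if not dataset:
            return  # missing dataset: IndexingRunner is never invoked

        # (billing checks omitted: on the sandbox plan an oversized batch marks
        # every document "error" with a "batch upload" message and stops here)

        documents = []
        for document_id in document_ids:
            document = (
                db.session.query(Document)
                .filter_by(id=document_id, dataset_id=dataset_id)
                .first()
            )
            if document:  # unknown IDs are silently skipped
                document.indexing_status = "parsing"
                document.processing_started_at = datetime.now(UTC)
                documents.append(document)
        db.session.commit()

        try:
            IndexingRunner().run(documents)
        except DocumentIsPausedError:
            logging.info("document indexing is paused")  # swallowed, task must not raise
        except Exception:
            logging.exception("document indexing failed")  # swallowed, task must not raise
        finally:
            db.session.close()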

@@ -15,7 +15,7 @@ class FakeResponse:
         self.status_code = status_code
         self.headers = headers or {}
         self.content = content
-        self.text = text if text else content.decode("utf-8", errors="ignore")
+        self.text = text or content.decode("utf-8", errors="ignore")


 # ---------------------------
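
The two spellings in this hunk are equivalent for strings: `text or default`
short-circuits exactly like `text if text else default`, treating both None
and "" as falsy, so the change is purely stylistic. A standalone check:

    for text in (None, "", "hello"):
        assert (text if text else "fallback") == (text or "fallback")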

@@ -1,3 +1,4 @@
+from pathlib import Path
 from unittest.mock import Mock, create_autospec, patch

 import pytest

@@ -146,19 +147,17 @@ class TestMetadataBugCompleteValidation:
         # Console API create
         console_create_file = "api/controllers/console/datasets/metadata.py"
         if os.path.exists(console_create_file):
-            with open(console_create_file) as f:
-                content = f.read()
-                # Should contain nullable=False, not nullable=True
-                assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]
+            content = Path(console_create_file).read_text()
+            # Should contain nullable=False, not nullable=True
+            assert "nullable=True" not in content.split("class DatasetMetadataCreateApi")[1].split("class")[0]

         # Service API create
         service_create_file = "api/controllers/service_api/dataset/metadata.py"
         if os.path.exists(service_create_file):
-            with open(service_create_file) as f:
-                content = f.read()
-                # Should contain nullable=False, not nullable=True
-                create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
-                assert "nullable=True" not in create_api_section
+            content = Path(service_create_file).read_text()
+            # Should contain nullable=False, not nullable=True
+            create_api_section = content.split("class DatasetMetadataCreateServiceApi")[1].split("class")[0]
+            assert "nullable=True" not in create_api_section


 class TestMetadataValidationSummary:
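
`Path.read_text()` opens, reads, and closes the file in one call, making it a
drop-in replacement for the removed `with open(...)` block whenever the whole
file is wanted as a single string. A minimal equivalence sketch (hypothetical
path):

    from pathlib import Path

    p = "api/controllers/console/datasets/metadata.py"
    with open(p) as f:
        via_open = f.read()
    assert Path(p).read_text() == via_open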

@@ -1,6 +1,7 @@
+from pathlib import Path
+
 import yaml  # type: ignore
 from dotenv import dotenv_values
-from pathlib import Path

 BASE_API_AND_DOCKER_CONFIG_SET_DIFF = {
     "APP_MAX_EXECUTION_TIME",

@@ -98,23 +99,15 @@ with open(Path("docker") / Path("docker-compose.yaml")) as f:

 def test_yaml_config():
     # python set == operator is used to compare two sets
-    DIFF_API_WITH_DOCKER = (
-        API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
-    )
+    DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
     if DIFF_API_WITH_DOCKER:
-        print(
-            f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}"
-        )
+        print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}")
         raise Exception("API and Docker config sets are different")
     DIFF_API_WITH_DOCKER_COMPOSE = (
-        API_CONFIG_SET
-        - DOCKER_COMPOSE_CONFIG_SET
-        - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
+        API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
     )
     if DIFF_API_WITH_DOCKER_COMPOSE:
-        print(
-            f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}"
-        )
+        print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}")
        raise Exception("API and Docker Compose config sets are different")
     print("All tests passed!")

@@ -643,9 +643,10 @@ VIKINGDB_CONNECTION_TIMEOUT=30
 VIKINGDB_SOCKET_TIMEOUT=30

 # Lindorm configuration, only available when VECTOR_STORE is `lindorm`
-LINDORM_URL=http://lindorm:30070
-LINDORM_USERNAME=lindorm
-LINDORM_PASSWORD=lindorm
+LINDORM_URL=http://localhost:30070
+LINDORM_USERNAME=admin
+LINDORM_PASSWORD=admin
 LINDORM_USING_UGC=True
+LINDORM_QUERY_TIMEOUT=1

 # OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase`

@@ -292,9 +292,10 @@ x-shared-env: &shared-api-worker-env
   VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http}
   VIKINGDB_CONNECTION_TIMEOUT: ${VIKINGDB_CONNECTION_TIMEOUT:-30}
   VIKINGDB_SOCKET_TIMEOUT: ${VIKINGDB_SOCKET_TIMEOUT:-30}
-  LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070}
-  LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm}
-  LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm}
+  LINDORM_URL: ${LINDORM_URL:-http://localhost:30070}
+  LINDORM_USERNAME: ${LINDORM_USERNAME:-admin}
+  LINDORM_PASSWORD: ${LINDORM_PASSWORD:-admin}
   LINDORM_USING_UGC: ${LINDORM_USING_UGC:-True}
+  LINDORM_QUERY_TIMEOUT: ${LINDORM_QUERY_TIMEOUT:-1}
   OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-oceanbase}
   OCEANBASE_VECTOR_PORT: ${OCEANBASE_VECTOR_PORT:-2881}
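
Compose's `${VAR:-default}` substitution falls back to the default when the
variable is unset or empty, so the new values only change out-of-the-box
behavior; explicit settings in `.env` still win. The Python analogue of the
updated Lindorm defaults (illustrative only):

    import os

    lindorm_url = os.environ.get("LINDORM_URL") or "http://localhost:30070"
    lindorm_username = os.environ.get("LINDORM_USERNAME") or "admin"
    lindorm_password = os.environ.get("LINDORM_PASSWORD") or "admin"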

@@ -51,9 +51,7 @@ def cleanup() -> None:
     if sys.stdin.isatty():
         log.separator()
         log.warning("This action cannot be undone!")
-        confirmation = input(
-            "Are you sure you want to remove all config and report files? (yes/no): "
-        )
+        confirmation = input("Are you sure you want to remove all config and report files? (yes/no): ")

         if confirmation.lower() not in ["yes", "y"]:
             log.error("Cleanup cancelled.")

@@ -3,4 +3,4 @@
 from .config_helper import config_helper
 from .logger_helper import Logger, ProgressLogger

-__all__ = ["config_helper", "Logger", "ProgressLogger"]
+__all__ = ["Logger", "ProgressLogger", "config_helper"]

@@ -65,9 +65,9 @@ class ConfigHelper:
             return None

         try:
-            with open(config_path, "r") as f:
+            with open(config_path) as f:
                 return json.load(f)
-        except (json.JSONDecodeError, IOError) as e:
+        except (OSError, json.JSONDecodeError) as e:
             print(f"❌ Error reading {filename}: {e}")
             return None
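
Since Python 3.3, `IOError` has been a plain alias of `OSError`, so swapping
the names here is behavior-preserving; `json.JSONDecodeError` subclasses
`ValueError`, which is why it still needs its own entry in the tuple:

    import json

    assert IOError is OSError  # holds on Python 3.3+
    assert issubclass(json.JSONDecodeError, ValueError)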

@@ -101,7 +101,7 @@ class ConfigHelper:
             with open(config_path, "w") as f:
                 json.dump(data, f, indent=2)
             return True
-        except IOError as e:
+        except OSError as e:
             print(f"❌ Error writing {filename}: {e}")
             return False

@@ -133,7 +133,7 @@ class ConfigHelper:
         try:
             config_path.unlink()
             return True
-        except IOError as e:
+        except OSError as e:
             print(f"❌ Error deleting {filename}: {e}")
             return False

@@ -148,9 +148,9 @@ class ConfigHelper:
             return None

         try:
-            with open(state_path, "r") as f:
+            with open(state_path) as f:
                 return json.load(f)
-        except (json.JSONDecodeError, IOError) as e:
+        except (OSError, json.JSONDecodeError) as e:
             print(f"❌ Error reading {self.state_file}: {e}")
             return None

@@ -170,7 +170,7 @@ class ConfigHelper:
             with open(state_path, "w") as f:
                 json.dump(data, f, indent=2)
             return True
-        except IOError as e:
+        except OSError as e:
             print(f"❌ Error writing {self.state_file}: {e}")
             return False

@@ -159,9 +159,7 @@ class ProgressLogger:

         if self.logger.use_colors:
             progress_bar = self._create_progress_bar()
-            print(
-                f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}"
-            )
+            print(f"\n\033[1m[Step {self.current_step}/{self.total_steps}]\033[0m {progress_bar}")
             self.logger.step(f"{description} (Elapsed: {elapsed:.1f}s)")
         else:
             print(f"\n[Step {self.current_step}/{self.total_steps}]")

@@ -6,8 +6,7 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

 import httpx
-from common import config_helper
-from common import Logger
+from common import Logger, config_helper


 def configure_openai_plugin() -> None:

@@ -72,29 +71,19 @@ def configure_openai_plugin() -> None:

         if response.status_code == 200:
             log.success("OpenAI plugin configured successfully!")
-            log.key_value(
-                "API Base", config_payload["credentials"]["openai_api_base"]
-            )
-            log.key_value(
-                "API Key", config_payload["credentials"]["openai_api_key"]
-            )
+            log.key_value("API Base", config_payload["credentials"]["openai_api_base"])
+            log.key_value("API Key", config_payload["credentials"]["openai_api_key"])

         elif response.status_code == 201:
             log.success("OpenAI plugin credentials created successfully!")
-            log.key_value(
-                "API Base", config_payload["credentials"]["openai_api_base"]
-            )
-            log.key_value(
-                "API Key", config_payload["credentials"]["openai_api_key"]
-            )
+            log.key_value("API Base", config_payload["credentials"]["openai_api_base"])
+            log.key_value("API Key", config_payload["credentials"]["openai_api_key"])

         elif response.status_code == 401:
             log.error("Configuration failed: Unauthorized")
             log.info("Token may have expired. Please run login_admin.py again")
         else:
-            log.error(
-                f"Configuration failed with status code: {response.status_code}"
-            )
+            log.error(f"Configuration failed with status code: {response.status_code}")
             log.debug(f"Response: {response.text}")

     except httpx.ConnectError:

@@ -5,10 +5,10 @@ from pathlib import Path

 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import json
-from common import config_helper
-from common import Logger
+
+import httpx
+from common import Logger, config_helper


 def create_api_key() -> None:

@@ -90,9 +90,7 @@ def create_api_key() -> None:
             }

             if config_helper.write_config("api_key_config", api_key_config):
-                log.info(
-                    f"API key saved to: {config_helper.get_config_path('benchmark_state')}"
-                )
+                log.info(f"API key saved to: {config_helper.get_config_path('benchmark_state')}")
             else:
                 log.error("No API token received")
                 log.debug(f"Response: {json.dumps(response_data, indent=2)}")

@@ -101,9 +99,7 @@ def create_api_key() -> None:
             log.error("API key creation failed: Unauthorized")
             log.info("Token may have expired. Please run login_admin.py again")
         else:
-            log.error(
-                f"API key creation failed with status code: {response.status_code}"
-            )
+            log.error(f"API key creation failed with status code: {response.status_code}")
             log.debug(f"Response: {response.text}")

     except httpx.ConnectError:

@@ -5,9 +5,10 @@ from pathlib import Path

 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import json
-from common import config_helper, Logger
+
+import httpx
+from common import Logger, config_helper


 def import_workflow_app() -> None:

@@ -30,7 +31,7 @@ def import_workflow_app() -> None:
         log.error(f"DSL file not found: {dsl_path}")
         return

-    with open(dsl_path, "r") as f:
+    with open(dsl_path) as f:
         yaml_content = f.read()

     log.step("Importing workflow app from DSL...")

@@ -86,9 +87,7 @@ def import_workflow_app() -> None:
             log.success("Workflow app imported successfully!")
             log.key_value("App ID", app_id)
             log.key_value("App Mode", response_data.get("app_mode"))
-            log.key_value(
-                "DSL Version", response_data.get("imported_dsl_version")
-            )
+            log.key_value("DSL Version", response_data.get("imported_dsl_version"))

             # Save app_id to config
             app_config = {

@@ -99,9 +98,7 @@ def import_workflow_app() -> None:
             }

             if config_helper.write_config("app_config", app_config):
-                log.info(
-                    f"App config saved to: {config_helper.get_config_path('benchmark_state')}"
-                )
+                log.info(f"App config saved to: {config_helper.get_config_path('benchmark_state')}")
             else:
                 log.error("Import completed but no app_id received")
                 log.debug(f"Response: {json.dumps(response_data, indent=2)}")

@@ -5,10 +5,10 @@ from pathlib import Path

 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import time
-from common import config_helper
-from common import Logger
+
+import httpx
+from common import Logger, config_helper


 def install_openai_plugin() -> None:

@@ -28,9 +28,7 @@ def install_openai_plugin() -> None:

     # API endpoint for plugin installation
     base_url = "http://localhost:5001"
-    install_endpoint = (
-        f"{base_url}/console/api/workspaces/current/plugin/install/marketplace"
-    )
+    install_endpoint = f"{base_url}/console/api/workspaces/current/plugin/install/marketplace"

     # Plugin identifier
     plugin_payload = {

@@ -83,9 +81,7 @@ def install_openai_plugin() -> None:
             log.info("Polling for task completion...")

             # Poll for task completion
-            task_endpoint = (
-                f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}"
-            )
+            task_endpoint = f"{base_url}/console/api/workspaces/current/plugin/tasks/{task_id}"

             max_attempts = 30  # 30 attempts with 2 second delay = 60 seconds max
             attempt = 0

@@ -131,9 +127,7 @@ def install_openai_plugin() -> None:
                         plugins = task_info.get("plugins", [])
                         if plugins:
                             for plugin in plugins:
-                                log.list_item(
-                                    f"{plugin.get('plugin_id')}: {plugin.get('message')}"
-                                )
+                                log.list_item(f"{plugin.get('plugin_id')}: {plugin.get('message')}")
                         break

                     # Continue polling if status is "pending" or other

@@ -149,9 +143,7 @@ def install_openai_plugin() -> None:
             log.warning("Plugin may already be installed")
             log.debug(f"Response: {response.text}")
         else:
-            log.error(
-                f"Installation failed with status code: {response.status_code}"
-            )
+            log.error(f"Installation failed with status code: {response.status_code}")
             log.debug(f"Response: {response.text}")

     except httpx.ConnectError:

@@ -5,10 +5,10 @@ from pathlib import Path

 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import json
-from common import config_helper
-from common import Logger
+
+import httpx
+from common import Logger, config_helper


 def login_admin() -> None:

@@ -77,16 +77,10 @@ def login_admin() -> None:

             # Save token config
             if config_helper.write_config("token_config", token_config):
-                log.info(
-                    f"Token saved to: {config_helper.get_config_path('benchmark_state')}"
-                )
+                log.info(f"Token saved to: {config_helper.get_config_path('benchmark_state')}")

                 # Show truncated token for verification
-                token_display = (
-                    f"{access_token[:20]}..."
-                    if len(access_token) > 20
-                    else "Token saved"
-                )
+                token_display = f"{access_token[:20]}..." if len(access_token) > 20 else "Token saved"
                 log.key_value("Access token", token_display)

         elif response.status_code == 401:

@@ -3,8 +3,10 @@
 import json
 import time
 import uuid
-from typing import Any, Iterator
-from flask import Flask, request, jsonify, Response
+from collections.abc import Iterator
+from typing import Any
+
+from flask import Flask, Response, jsonify, request

 app = Flask(__name__)

@@ -5,10 +5,10 @@ from pathlib import Path

 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import json
-from common import config_helper
-from common import Logger
+
+import httpx
+from common import Logger, config_helper


 def publish_workflow() -> None:

@@ -79,9 +79,7 @@ def publish_workflow() -> None:
             try:
                 response_data = response.json()
                 if response_data:
-                    log.debug(
-                        f"Response: {json.dumps(response_data, indent=2)}"
-                    )
+                    log.debug(f"Response: {json.dumps(response_data, indent=2)}")
             except json.JSONDecodeError:
                 # Response might be empty or non-JSON
                 pass

@@ -93,9 +91,7 @@ def publish_workflow() -> None:
             log.error("Workflow publish failed: App not found")
             log.info("Make sure the app was imported successfully")
         else:
-            log.error(
-                f"Workflow publish failed with status code: {response.status_code}"
-            )
+            log.error(f"Workflow publish failed with status code: {response.status_code}")
             log.debug(f"Response: {response.text}")

     except httpx.ConnectError:

@@ -5,9 +5,10 @@ from pathlib import Path

 sys.path.append(str(Path(__file__).parent.parent))

-import httpx
 import json
-from common import config_helper, Logger
+
+import httpx
+from common import Logger, config_helper


 def run_workflow(question: str = "fake question", streaming: bool = True) -> None:

@@ -70,9 +71,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
                         event = data.get("event")

                         if event == "workflow_started":
-                            log.progress(
-                                f"Workflow started: {data.get('data', {}).get('id')}"
-                            )
+                            log.progress(f"Workflow started: {data.get('data', {}).get('id')}")
                         elif event == "node_started":
                             node_data = data.get("data", {})
                             log.progress(

@@ -116,9 +115,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
                         # Some lines might not be JSON
                         pass
             else:
-                log.error(
-                    f"Workflow run failed with status code: {response.status_code}"
-                )
+                log.error(f"Workflow run failed with status code: {response.status_code}")
                 log.debug(f"Response: {response.text}")
         else:
             # Handle blocking response

@@ -142,9 +139,7 @@ def run_workflow(question: str = "fake question", streaming: bool = True) -> None:
                 log.info("📤 Final Answer:")
                 log.info(outputs.get("answer"), indent=2)
             else:
-                log.error(
-                    f"Workflow run failed with status code: {response.status_code}"
-                )
+                log.error(f"Workflow run failed with status code: {response.status_code}")
                 log.debug(f"Response: {response.text}")

     except httpx.ConnectError:

@@ -6,7 +6,7 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).parent.parent))

 import httpx
-from common import config_helper, Logger
+from common import Logger, config_helper


 def setup_admin_account() -> None:

@@ -24,9 +24,7 @@ def setup_admin_account() -> None:

     # Save credentials to config file
     if config_helper.write_config("admin_config", admin_config):
-        log.info(
-            f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}"
-        )
+        log.info(f"Admin credentials saved to: {config_helper.get_config_path('benchmark_state')}")

     # API setup endpoint
     base_url = "http://localhost:5001"

@@ -56,9 +54,7 @@ def setup_admin_account() -> None:
         log.key_value("Username", admin_config["username"])

     elif response.status_code == 400:
-        log.warning(
-            "Setup may have already been completed or invalid data provided"
-        )
+        log.warning("Setup may have already been completed or invalid data provided")
         log.debug(f"Response: {response.text}")
     else:
         log.error(f"Setup failed with status code: {response.status_code}")

@@ -1,9 +1,9 @@
 #!/usr/bin/env python3

+import socket
 import subprocess
 import sys
 import time
-import socket
 from pathlib import Path

 from common import Logger, ProgressLogger

@@ -93,9 +93,7 @@ def main() -> None:
         if retry.lower() in ["yes", "y"]:
             return main()  # Recursively call main to check again
         else:
-            print(
-                "❌ Setup cancelled. Please start the required services and try again."
-            )
+            print("❌ Setup cancelled. Please start the required services and try again.")
             sys.exit(1)

     log.success("All required services are running!")

@@ -7,29 +7,28 @@ measuring key metrics like connection rate, event throughput, and time to first
 """

 import json
-import time
+import logging
+import os
 import random
+import statistics
 import sys
 import threading
-import os
-import logging
-import statistics
-from pathlib import Path
+import time
 from collections import deque
+from dataclasses import asdict, dataclass
 from datetime import datetime
-from dataclasses import dataclass, asdict
-from locust import HttpUser, task, between, events, constant
-from typing import TypedDict, Literal, TypeAlias
+from pathlib import Path
+from typing import Literal, TypeAlias, TypedDict

+import requests.exceptions
+from locust import HttpUser, between, constant, events, task

 # Add the stress-test directory to path to import common modules
 sys.path.insert(0, str(Path(__file__).parent))
 from common.config_helper import ConfigHelper  # type: ignore[import-not-found]

 # Configure logging
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)

 # Configuration from environment

@@ -54,6 +53,7 @@ ErrorType: TypeAlias = Literal[

 class ErrorCounts(TypedDict):
     """Error count tracking"""
+
     connection_error: int
     timeout: int
     invalid_json: int

@@ -65,6 +65,7 @@ class ErrorCounts(TypedDict):

 class SSEEvent(TypedDict):
     """Server-Sent Event structure"""
+
     data: str
     event: str
     id: str | None

@@ -72,11 +73,13 @@ class SSEEvent(TypedDict):

 class WorkflowInputs(TypedDict):
     """Workflow input structure"""
+
     question: str


 class WorkflowRequestData(TypedDict):
     """Workflow request payload"""
+
     inputs: WorkflowInputs
     response_mode: Literal["streaming"]
     user: str

@@ -84,6 +87,7 @@ class WorkflowRequestData(TypedDict):

 class ParsedEventData(TypedDict, total=False):
     """Parsed event data from SSE stream"""
+
     event: str
     task_id: str
     workflow_run_id: str

@@ -93,6 +97,7 @@ class ParsedEventData(TypedDict, total=False):

 class LocustStats(TypedDict):
     """Locust statistics structure"""
+
     total_requests: int
     total_failures: int
     avg_response_time: float

@@ -102,6 +107,7 @@ class LocustStats(TypedDict):

 class ReportData(TypedDict):
     """JSON report structure"""
+
     timestamp: str
     duration_seconds: float
     metrics: dict[str, object]  # Metrics as dict for JSON serialization

@@ -154,7 +160,7 @@ class MetricsTracker:
         self.total_connections = 0
         self.total_events = 0
         self.start_time = time.time()
-
+
         # Enhanced metrics with memory limits
         self.max_samples = 10000  # Prevent unbounded growth
         self.ttfe_samples: deque[float] = deque(maxlen=self.max_samples)
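
`deque(maxlen=...)` gives O(1) appends that silently evict the oldest entry
once the cap is reached, which is why the reports below distinguish "window"
samples from "total" samples. A standalone illustration:

    from collections import deque

    samples: deque[int] = deque(maxlen=3)
    for i in range(5):
        samples.append(i)
    assert list(samples) == [2, 3, 4]  # oldest values were evicted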

@@ -233,9 +239,7 @@ class MetricsTracker:
             max_ttfe = max(self.ttfe_samples)
             p50_ttfe = statistics.median(self.ttfe_samples)
             if len(self.ttfe_samples) >= 2:
-                quantiles = statistics.quantiles(
-                    self.ttfe_samples, n=20, method="inclusive"
-                )
+                quantiles = statistics.quantiles(self.ttfe_samples, n=20, method="inclusive")
                 p95_ttfe = quantiles[18]  # 19th of 19 quantiles = 95th percentile
             else:
                 p95_ttfe = max_ttfe
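
`statistics.quantiles(data, n=20)` returns the 19 cut points that divide the
data into 20 equal slices, so index 18 (the last cut point) is the 95th
percentile; that is also why the guard requires at least two samples. A quick
check:

    import statistics

    data = list(range(1, 101))  # 1..100
    q = statistics.quantiles(data, n=20, method="inclusive")
    assert len(q) == 19
    assert 95.0 <= q[18] <= 95.1  # ~95th percentile of 1..100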

@@ -255,9 +259,7 @@ class MetricsTracker:
                 if durations
                 else 0
             )
-            events_per_stream_avg = (
-                statistics.mean(events_per_stream) if events_per_stream else 0
-            )
+            events_per_stream_avg = statistics.mean(events_per_stream) if events_per_stream else 0

             # Calculate inter-event latency statistics
             all_inter_event_times = []

@@ -268,32 +270,20 @@ class MetricsTracker:
                 inter_event_latency_avg = statistics.mean(all_inter_event_times)
                 inter_event_latency_p50 = statistics.median(all_inter_event_times)
                 inter_event_latency_p95 = (
-                    statistics.quantiles(
-                        all_inter_event_times, n=20, method="inclusive"
-                    )[18]
+                    statistics.quantiles(all_inter_event_times, n=20, method="inclusive")[18]
                     if len(all_inter_event_times) >= 2
                     else max(all_inter_event_times)
                 )
             else:
-                inter_event_latency_avg = inter_event_latency_p50 = (
-                    inter_event_latency_p95
-                ) = 0
+                inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0
         else:
-            stream_duration_avg = stream_duration_p50 = stream_duration_p95 = (
-                events_per_stream_avg
-            ) = 0
-            inter_event_latency_avg = inter_event_latency_p50 = (
-                inter_event_latency_p95
-            ) = 0
+            stream_duration_avg = stream_duration_p50 = stream_duration_p95 = events_per_stream_avg = 0
+            inter_event_latency_avg = inter_event_latency_p50 = inter_event_latency_p95 = 0

         # Also calculate overall average rates
         total_elapsed = current_time - self.start_time
-        overall_conn_rate = (
-            self.total_connections / total_elapsed if total_elapsed > 0 else 0
-        )
-        overall_event_rate = (
-            self.total_events / total_elapsed if total_elapsed > 0 else 0
-        )
+        overall_conn_rate = self.total_connections / total_elapsed if total_elapsed > 0 else 0
+        overall_event_rate = self.total_events / total_elapsed if total_elapsed > 0 else 0

         return MetricsSnapshot(
             active_connections=self.active_connections,

@@ -389,7 +379,7 @@ class DifyWorkflowUser(HttpUser):

         # Load questions from file or use defaults
         if QUESTIONS_FILE and os.path.exists(QUESTIONS_FILE):
-            with open(QUESTIONS_FILE, "r") as f:
+            with open(QUESTIONS_FILE) as f:
                 self.questions = [line.strip() for line in f if line.strip()]
         else:
             self.questions = [

@@ -451,18 +441,13 @@ class DifyWorkflowUser(HttpUser):
             try:
                 # Validate response
                 if response.status_code >= 400:
-                    error_type: ErrorType = (
-                        "http_4xx" if response.status_code < 500 else "http_5xx"
-                    )
+                    error_type: ErrorType = "http_4xx" if response.status_code < 500 else "http_5xx"
                     metrics.record_error(error_type)
                     response.failure(f"HTTP {response.status_code}")
                     return

                 content_type = response.headers.get("Content-Type", "")
-                if (
-                    "text/event-stream" not in content_type
-                    and "application/json" not in content_type
-                ):
+                if "text/event-stream" not in content_type and "application/json" not in content_type:
                     logger.error(f"Expected text/event-stream, got: {content_type}")
                     metrics.record_error("invalid_response")
                     response.failure(f"Invalid content type: {content_type}")

@@ -473,10 +458,13 @@ class DifyWorkflowUser(HttpUser):

                 for line in response.iter_lines(decode_unicode=True):
                     # Check if runner is stopping
-                    if getattr(self.environment.runner, 'state', '') in ('stopping', 'stopped'):
+                    if getattr(self.environment.runner, "state", "") in (
+                        "stopping",
+                        "stopped",
+                    ):
                         logger.debug("Runner stopping, breaking streaming loop")
                         break


                     if line is not None:
                         bytes_received += len(line.encode("utf-8"))

@@ -489,9 +477,7 @@ class DifyWorkflowUser(HttpUser):

                         # Track inter-event timing
                         if last_event_time:
-                            inter_event_times.append(
-                                (current_time - last_event_time) * 1000
-                            )
+                            inter_event_times.append((current_time - last_event_time) * 1000)
                         last_event_time = current_time

                         if first_event_time is None:

@@ -512,15 +498,11 @@ class DifyWorkflowUser(HttpUser):
                                 parsed_event: ParsedEventData = json.loads(event_data)
                                 # Check for terminal events
                                 if parsed_event.get("event") in TERMINAL_EVENTS:
-                                    logger.debug(
-                                        f"Received terminal event: {parsed_event.get('event')}"
-                                    )
+                                    logger.debug(f"Received terminal event: {parsed_event.get('event')}")
                                     request_success = True
                                     break
                             except json.JSONDecodeError as e:
-                                logger.debug(
-                                    f"JSON decode error: {e} for data: {event_data[:100]}"
-                                )
+                                logger.debug(f"JSON decode error: {e} for data: {event_data[:100]}")
                                 metrics.record_error("invalid_json")

             except Exception as e:
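
For context on the loop above: each SSE payload line arrives as "data: "
followed by JSON, and the stream ends once an event named in TERMINAL_EVENTS
appears. A minimal standalone parser in the same spirit (event names invented
for illustration):

    import json

    TERMINAL = {"workflow_finished", "error"}

    def is_terminal(line: str) -> bool:
        if not line.startswith("data: "):
            return False  # comments and keep-alives are ignored
        try:
            payload = json.loads(line[len("data: "):])
        except json.JSONDecodeError:
            return False
        return payload.get("event") in TERMINAL

    assert is_terminal('data: {"event": "workflow_finished"}')
    assert not is_terminal(": keep-alive")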

@@ -583,16 +565,18 @@ def on_test_start(environment: object, **kwargs: object) -> None:

     # Periodic stats reporting
     def report_stats() -> None:
-        if not hasattr(environment, 'runner'):
+        if not hasattr(environment, "runner"):
             return
         runner = environment.runner
-        while hasattr(runner, 'state') and runner.state not in ["stopped", "stopping"]:
+        while hasattr(runner, "state") and runner.state not in ["stopped", "stopping"]:
             time.sleep(5)  # Report every 5 seconds
-            if hasattr(runner, 'state') and runner.state == "running":
+            if hasattr(runner, "state") and runner.state == "running":
                 stats = metrics.get_stats()

                 # Only log on master node in distributed mode
-                is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, 'runner') else True
+                is_master = (
+                    not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True
+                )
                 if is_master:
                     # Clear previous lines and show updated stats
                     logger.info("\n" + "=" * 80)

@@ -623,15 +607,15 @@ def on_test_start(environment: object, **kwargs: object) -> None:
                     logger.info(
                         f"{'(TTFE in ms)':<25} {stats.ttfe_avg:>15.1f} {stats.ttfe_p50:>10.1f} {stats.ttfe_p95:>10.1f} {stats.ttfe_min:>10.1f} {stats.ttfe_max:>10.1f}"
                     )
-                    logger.info(f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)")
+                    logger.info(
+                        f"{'Window Samples':<25} {stats.ttfe_samples:>15,d} (last {min(10000, stats.ttfe_total_samples):,d} samples)"
+                    )
                     logger.info(f"{'Total Samples':<25} {stats.ttfe_total_samples:>15,d}")

                     # Inter-event latency
                     if stats.inter_event_latency_avg > 0:
                         logger.info("-" * 80)
-                        logger.info(
-                            f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}"
-                        )
+                        logger.info(f"{'INTER-EVENT LATENCY':<25} {'AVG':>15} {'P50':>10} {'P95':>10}")
                         logger.info(
                             f"{'(ms between events)':<25} {stats.inter_event_latency_avg:>15.1f} {stats.inter_event_latency_p50:>10.1f} {stats.inter_event_latency_p95:>10.1f}"
                         )

@@ -647,9 +631,9 @@ def on_test_start(environment: object, **kwargs: object) -> None:
                     logger.info("=" * 80)

                     # Show Locust stats summary
-                    if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'):
+                    if hasattr(environment, "stats") and hasattr(environment.stats, "total"):
                         total = environment.stats.total
-                        if hasattr(total, 'num_requests') and total.num_requests > 0:
+                        if hasattr(total, "num_requests") and total.num_requests > 0:
                             logger.info(
                                 f"{'LOCUST STATS':<25} {'Requests':>12} {'Fails':>8} {'Avg (ms)':>12} {'Min':>8} {'Max':>8}"
                             )

@@ -687,21 +671,15 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
     logger.info("")
     logger.info("EVENTS")
     logger.info(f"  {'Total Events Received:':<30} {stats.total_events:>10,d}")
-    logger.info(
-        f"  {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s"
-    )
-    logger.info(
-        f"  {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s"
-    )
+    logger.info(f"  {'Average Throughput:':<30} {stats.overall_event_rate:>10.2f} events/s")
+    logger.info(f"  {'Final Rate (10s window):':<30} {stats.event_rate:>10.2f} events/s")

     logger.info("")
     logger.info("STREAM METRICS")
     logger.info(f"  {'Avg Stream Duration:':<30} {stats.stream_duration_avg:>10.1f} ms")
     logger.info(f"  {'P50 Stream Duration:':<30} {stats.stream_duration_p50:>10.1f} ms")
     logger.info(f"  {'P95 Stream Duration:':<30} {stats.stream_duration_p95:>10.1f} ms")
-    logger.info(
-        f"  {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}"
-    )
+    logger.info(f"  {'Avg Events per Stream:':<30} {stats.events_per_stream_avg:>10.1f}")

     logger.info("")
     logger.info("INTER-EVENT LATENCY")

@@ -716,7 +694,9 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
     logger.info(f"  {'95th Percentile:':<30} {stats.ttfe_p95:>10.1f} ms")
     logger.info(f"  {'Minimum:':<30} {stats.ttfe_min:>10.1f} ms")
     logger.info(f"  {'Maximum:':<30} {stats.ttfe_max:>10.1f} ms")
-    logger.info(f"  {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})")
+    logger.info(
+        f"  {'Window Samples:':<30} {stats.ttfe_samples:>10,d} (last {min(10000, stats.ttfe_total_samples):,d})"
+    )
     logger.info(f"  {'Total Samples:':<30} {stats.ttfe_total_samples:>10,d}")

     # Error summary

@@ -730,7 +710,7 @@ def on_test_stop(environment: object, **kwargs: object) -> None:
     logger.info("=" * 80 + "\n")

     # Export machine-readable report (only on master node)
-    is_master = not getattr(environment.runner, 'worker_id', None) if hasattr(environment, 'runner') else True
+    is_master = not getattr(environment.runner, "worker_id", None) if hasattr(environment, "runner") else True
     if is_master:
         export_json_report(stats, test_duration, environment)

@@ -746,9 +726,9 @@ def export_json_report(stats: MetricsSnapshot, duration: float, environment: object) -> None:

     # Access environment.stats.total attributes safely
     locust_stats: LocustStats | None = None
-    if hasattr(environment, 'stats') and hasattr(environment.stats, 'total'):
+    if hasattr(environment, "stats") and hasattr(environment.stats, "total"):
         total = environment.stats.total
-        if hasattr(total, 'num_requests') and total.num_requests > 0:
+        if hasattr(total, "num_requests") and total.num_requests > 0:
             locust_stats = LocustStats(
                 total_requests=total.num_requests,
                 total_failures=total.num_failures,

@@ -1,7 +1,15 @@
 from dify_client.client import (
     ChatClient,
     CompletionClient,
-    WorkflowClient,
-    KnowledgeBaseClient,
     DifyClient,
+    KnowledgeBaseClient,
+    WorkflowClient,
 )
+
+__all__ = [
+    "ChatClient",
+    "CompletionClient",
+    "DifyClient",
+    "KnowledgeBaseClient",
+    "WorkflowClient",
+]

@@ -8,16 +8,16 @@ class DifyClient:
         self.api_key = api_key
         self.base_url = base_url

-    def _send_request(self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False):
+    def _send_request(
+        self, method: str, endpoint: str, json: dict | None = None, params: dict | None = None, stream: bool = False
+    ):
         headers = {
             "Authorization": f"Bearer {self.api_key}",
             "Content-Type": "application/json",
         }

         url = f"{self.base_url}{endpoint}"
-        response = requests.request(
-            method, url, json=json, params=params, headers=headers, stream=stream
-        )
+        response = requests.request(method, url, json=json, params=params, headers=headers, stream=stream)

         return response
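
Every client class funnels its HTTP calls through this helper, so usage is
uniform across them. A short sketch against a local instance (the key, URL,
and inputs are placeholders):

    client = WorkflowClient(api_key="your-api-key", base_url="http://localhost:5001/v1")
    response = client.run(inputs={"question": "hello"}, response_mode="blocking", user="abc-123")
    print(response.status_code, response.json())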
|
||||
|
||||
@@ -25,9 +25,7 @@ class DifyClient:
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
response = requests.request(
|
||||
method, url, data=data, headers=headers, files=files
|
||||
)
|
||||
response = requests.request(method, url, data=data, headers=headers, files=files)
|
||||
|
||||
return response
|
||||
|
||||
@@ -41,9 +39,7 @@ class DifyClient:
|
||||
|
||||
def file_upload(self, user: str, files: dict):
|
||||
data = {"user": user}
|
||||
return self._send_request_with_files(
|
||||
"POST", "/files/upload", data=data, files=files
|
||||
)
|
||||
return self._send_request_with_files("POST", "/files/upload", data=data, files=files)
|
||||
|
||||
def text_to_audio(self, text: str, user: str, streaming: bool = False):
|
||||
data = {"text": text, "user": user, "streaming": streaming}
|
||||
@@ -55,7 +51,9 @@ class DifyClient:
|
||||
|
||||
|
||||
class CompletionClient(DifyClient):
|
||||
def create_completion_message(self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None):
|
||||
def create_completion_message(
|
||||
self, inputs: dict, response_mode: Literal["blocking", "streaming"], user: str, files: dict | None = None
|
||||
):
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"response_mode": response_mode,
|
||||
@@ -99,9 +97,7 @@ class ChatClient(DifyClient):
|
||||
|
||||
def get_suggested(self, message_id: str, user: str):
|
||||
params = {"user": user}
|
||||
return self._send_request(
|
||||
"GET", f"/messages/{message_id}/suggested", params=params
|
||||
)
|
||||
return self._send_request("GET", f"/messages/{message_id}/suggested", params=params)
|
||||
|
||||
def stop_message(self, task_id: str, user: str):
|
||||
data = {"user": user}
|
||||
@@ -112,10 +108,9 @@ class ChatClient(DifyClient):
|
||||
user: str,
|
||||
last_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
pinned: bool | None = None
|
||||
pinned: bool | None = None,
|
||||
):
|
||||
params = {"user": user, "last_id": last_id,
|
||||
"limit": limit, "pinned": pinned}
|
||||
params = {"user": user, "last_id": last_id, "limit": limit, "pinned": pinned}
|
||||
return self._send_request("GET", "/conversations", params=params)
|
||||
|
||||
def get_conversation_messages(
|
||||
@@ -123,7 +118,7 @@ class ChatClient(DifyClient):
|
||||
user: str,
|
||||
conversation_id: str | None = None,
|
||||
first_id: str | None = None,
|
||||
limit: int | None = None
|
||||
limit: int | None = None,
|
||||
):
|
||||
params = {"user": user}
|
||||
|
||||
@@ -136,13 +131,9 @@ class ChatClient(DifyClient):
|
||||
|
||||
return self._send_request("GET", "/messages", params=params)
|
||||
|
||||
def rename_conversation(
|
||||
self, conversation_id: str, name: str, auto_generate: bool, user: str
|
||||
):
|
||||
def rename_conversation(self, conversation_id: str, name: str, auto_generate: bool, user: str):
|
||||
data = {"name": name, "auto_generate": auto_generate, "user": user}
|
||||
return self._send_request(
|
||||
"POST", f"/conversations/{conversation_id}/name", data
|
||||
)
|
||||
return self._send_request("POST", f"/conversations/{conversation_id}/name", data)
|
||||
|
||||
def delete_conversation(self, conversation_id: str, user: str):
|
||||
data = {"user": user}
|
||||
@@ -155,9 +146,7 @@ class ChatClient(DifyClient):
|
||||
|
||||
|
||||
class WorkflowClient(DifyClient):
|
||||
def run(
|
||||
self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"
|
||||
):
|
||||
def run(self, inputs: dict, response_mode: Literal["blocking", "streaming"] = "streaming", user: str = "abc-123"):
|
||||
data = {"inputs": inputs, "response_mode": response_mode, "user": user}
|
||||
return self._send_request("POST", "/workflows/run", data)
|
||||
|
||||
@@ -197,13 +186,9 @@ class KnowledgeBaseClient(DifyClient):
|
||||
return self._send_request("POST", "/datasets", {"name": name}, **kwargs)
|
||||
|
||||
def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
|
||||
return self._send_request(
|
||||
"GET", f"/datasets?page={page}&limit={page_size}", **kwargs
|
||||
)
|
||||
return self._send_request("GET", f"/datasets?page={page}&limit={page_size}", **kwargs)
|
||||
|
||||
def create_document_by_text(
|
||||
self, name, text, extra_params: dict | None = None, **kwargs
|
||||
):
|
||||
def create_document_by_text(self, name, text, extra_params: dict | None = None, **kwargs):
|
||||
"""
|
||||
Create a document by text.
|
||||
|
||||
@@ -272,9 +257,7 @@ class KnowledgeBaseClient(DifyClient):
|
||||
data = {"name": name, "text": text}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
url = (
|
||||
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
|
||||
)
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
|
||||
return self._send_request("POST", url, json=data, **kwargs)
|
||||
|
||||
def create_document_by_file(
|
||||
@@ -315,13 +298,9 @@ class KnowledgeBaseClient(DifyClient):
|
||||
if original_document_id is not None:
|
||||
data["original_document_id"] = original_document_id
|
||||
url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
|
||||
return self._send_request_with_files(
|
||||
"POST", url, {"data": json.dumps(data)}, files
|
||||
)
|
||||
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
|
||||
|
||||
def update_document_by_file(
|
||||
self, document_id: str, file_path: str, extra_params: dict | None = None
|
||||
):
|
||||
def update_document_by_file(self, document_id: str, file_path: str, extra_params: dict | None = None):
|
||||
"""
|
||||
Update a document by file.
|
||||
|
||||
@@ -351,12 +330,8 @@ class KnowledgeBaseClient(DifyClient):
|
||||
data = {}
|
||||
if extra_params is not None and isinstance(extra_params, dict):
|
||||
data.update(extra_params)
|
||||
url = (
|
||||
f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
|
||||
)
|
||||
return self._send_request_with_files(
|
||||
"POST", url, {"data": json.dumps(data)}, files
|
||||
)
|
||||
url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
|
||||
return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
|
||||
|
||||
def batch_indexing_status(self, batch_id: str, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
 from setuptools import setup
 
-with open("README.md", "r", encoding="utf-8") as fh:
+with open("README.md", encoding="utf-8") as fh:
     long_description = fh.read()
 
 setup(
@@ -18,9 +18,7 @@ FILE_PATH_BASE = os.path.dirname(__file__)
 class TestKnowledgeBaseClient(unittest.TestCase):
     def setUp(self):
         self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL)
-        self.README_FILE_PATH = os.path.abspath(
-            os.path.join(FILE_PATH_BASE, "../README.md")
-        )
+        self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md"))
         self.dataset_id = None
         self.document_id = None
         self.segment_id = None
@@ -28,9 +26,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
 
     def _get_dataset_kb_client(self):
         self.assertIsNotNone(self.dataset_id)
-        return KnowledgeBaseClient(
-            API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id
-        )
+        return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id)
 
     def test_001_create_dataset(self):
         response = self.knowledge_base_client.create_dataset(name="test_dataset")
@@ -76,9 +72,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
     def _test_004_update_document_by_text(self):
         client = self._get_dataset_kb_client()
         self.assertIsNotNone(self.document_id)
-        response = client.update_document_by_text(
-            self.document_id, "test_document_updated", "test_text_updated"
-        )
+        response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated")
         data = response.json()
         self.assertIn("document", data)
         self.assertIn("batch", data)
@@ -93,9 +87,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
     def _test_006_update_document_by_file(self):
         client = self._get_dataset_kb_client()
         self.assertIsNotNone(self.document_id)
-        response = client.update_document_by_file(
-            self.document_id, self.README_FILE_PATH
-        )
+        response = client.update_document_by_file(self.document_id, self.README_FILE_PATH)
         data = response.json()
         self.assertIn("document", data)
         self.assertIn("batch", data)
@@ -125,9 +117,7 @@ class TestKnowledgeBaseClient(unittest.TestCase):
 
     def _test_010_add_segments(self):
         client = self._get_dataset_kb_client()
-        response = client.add_segments(
-            self.document_id, [{"content": "test text segment 1"}]
-        )
+        response = client.add_segments(self.document_id, [{"content": "test text segment 1"}])
         data = response.json()
         self.assertIn("data", data)
         self.assertGreater(len(data["data"]), 0)
@@ -174,18 +164,12 @@ class TestChatClient(unittest.TestCase):
         self.chat_client = ChatClient(API_KEY)
 
     def test_create_chat_message(self):
-        response = self.chat_client.create_chat_message(
-            {}, "Hello, World!", "test_user"
-        )
+        response = self.chat_client.create_chat_message({}, "Hello, World!", "test_user")
         self.assertIn("answer", response.text)
 
     def test_create_chat_message_with_vision_model_by_remote_url(self):
-        files = [
-            {"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}
-        ]
-        response = self.chat_client.create_chat_message(
-            {}, "Describe the picture.", "test_user", files=files
-        )
+        files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}]
+        response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
         self.assertIn("answer", response.text)
 
     def test_create_chat_message_with_vision_model_by_local_file(self):
@@ -196,15 +180,11 @@
                 "upload_file_id": "your_file_id",
             }
         ]
-        response = self.chat_client.create_chat_message(
-            {}, "Describe the picture.", "test_user", files=files
-        )
+        response = self.chat_client.create_chat_message({}, "Describe the picture.", "test_user", files=files)
         self.assertIn("answer", response.text)
 
     def test_get_conversation_messages(self):
-        response = self.chat_client.get_conversation_messages(
-            "test_user", "your_conversation_id"
-        )
+        response = self.chat_client.get_conversation_messages("test_user", "your_conversation_id")
         self.assertIn("answer", response.text)
 
     def test_get_conversations(self):
@@ -223,9 +203,7 @@ class TestCompletionClient(unittest.TestCase):
         self.assertIn("answer", response.text)
 
     def test_create_completion_message_with_vision_model_by_remote_url(self):
-        files = [
-            {"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}
-        ]
+        files = [{"type": "image", "transfer_method": "remote_url", "url": "your_image_url"}]
         response = self.completion_client.create_completion_message(
             {"query": "Describe the picture."}, "blocking", "test_user", files
         )
@@ -250,9 +228,7 @@ class TestDifyClient(unittest.TestCase):
         self.dify_client = DifyClient(API_KEY)
 
     def test_message_feedback(self):
-        response = self.dify_client.message_feedback(
-            "your_message_id", "like", "test_user"
-        )
+        response = self.dify_client.message_feedback("your_message_id", "like", "test_user")
         self.assertIn("success", response.text)
 
     def test_get_application_parameters(self):
@@ -26,9 +26,16 @@ export const useTags = (translateFromOut?: TFunction) => {
     return acc
   }, {} as Record<string, Tag>)
 
+  const getTagLabel = (name: string) => {
+    if (!tagsMap[name])
+      return name
+    return tagsMap[name].label
+  }
+
   return {
     tags,
     tagsMap,
+    getTagLabel,
  }
 }
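The new `getTagLabel` helper exists because indexing `tagsMap[name].label` directly throws a TypeError whenever the marketplace returns a tag name missing from the local map (for example, a tag introduced in a newer release). A minimal standalone sketch of the guarded-lookup idea, using an assumed `Tag` shape rather than Dify's actual type:

```ts
// Sketch: fall back to the raw tag name when no label is known.
// The Tag shape below is an assumption for illustration.
type Tag = { name: string; label: string }

const tagsMap: Record<string, Tag> = {
  search: { name: 'search', label: 'Search' },
}

const getTagLabel = (name: string): string =>
  tagsMap[name] ? tagsMap[name].label : name

console.log(getTagLabel('search')) // "Search" — known tag
console.log(getTagLabel('agent-v2')) // "agent-v2" — unknown tag, no crash
```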
@@ -29,7 +29,7 @@ const CardWrapper = ({
     setFalse: hideInstallFromMarketplace,
   }] = useBoolean(false)
   const { locale: localeFromLocale } = useI18N()
-  const { tagsMap } = useTags(t)
+  const { getTagLabel } = useTags(t)
 
   if (showInstallButton) {
     return (
@@ -43,7 +43,7 @@ const CardWrapper = ({
           footer={
             <CardMoreInfo
               downloadCount={plugin.install_count}
-              tags={plugin.tags.map(tag => tagsMap[tag.name].label)}
+              tags={plugin.tags.map(tag => getTagLabel(tag.name))}
             />
           }
         />
@@ -92,7 +92,7 @@ const CardWrapper = ({
         footer={
           <CardMoreInfo
             downloadCount={plugin.install_count}
-            tags={plugin.tags.map(tag => tagsMap[tag.name].label)}
+            tags={plugin.tags.map(tag => getTagLabel(tag.name))}
           />
         }
       />
@@ -7,6 +7,7 @@ import type {
   PluginsSearchParams,
 } from '@/app/components/plugins/marketplace/types'
 import {
+  APP_VERSION,
   MARKETPLACE_API_PREFIX,
 } from '@/config'
 import { getMarketplaceUrl } from '@/utils/var'
@@ -49,11 +50,15 @@ export const getMarketplacePluginsByCollectionId = async (collectionId: string,
 
   try {
     const url = `${MARKETPLACE_API_PREFIX}/collections/${collectionId}/plugins`
+    const headers = new Headers({
+      'X-Dify-Version': APP_VERSION,
+    })
     const marketplaceCollectionPluginsData = await globalThis.fetch(
       url,
       {
         cache: 'no-store',
         method: 'POST',
+        headers,
         body: JSON.stringify({
           category: query?.category,
           exclude: query?.exclude,
@@ -83,7 +88,10 @@ export const getMarketplaceCollectionsAndPlugins = async (query?: CollectionsAnd
     marketplaceUrl += `&condition=${query.condition}`
   if (query?.type)
     marketplaceUrl += `&type=${query.type}`
-  const marketplaceCollectionsData = await globalThis.fetch(marketplaceUrl, { cache: 'no-store' })
+  const headers = new Headers({
+    'X-Dify-Version': APP_VERSION,
+  })
+  const marketplaceCollectionsData = await globalThis.fetch(marketplaceUrl, { headers, cache: 'no-store' })
   const marketplaceCollectionsDataJson = await marketplaceCollectionsData.json()
   marketplaceCollections = marketplaceCollectionsDataJson.data.collections
   await Promise.all(marketplaceCollections.map(async (collection: MarketplaceCollection) => {
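Both marketplace fetch paths now send the client's version so the server can filter out catalog entries (such as tags) that this build does not understand. A minimal standalone sketch of the pattern; the version string and URL handling are placeholders, not Dify's real config:

```ts
// Sketch: version-stamped fetch so the server can tailor responses
// to what this client understands. APP_VERSION here is a placeholder.
const APP_VERSION = '1.0.0'

async function fetchMarketplaceJson(url: string): Promise<unknown> {
  const headers = new Headers({ 'X-Dify-Version': APP_VERSION })
  const res = await fetch(url, { headers, cache: 'no-store' })
  if (!res.ok)
    throw new Error(`marketplace request failed: ${res.status}`)
  return res.json()
}
```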
@@ -27,7 +27,7 @@ const TagsFilter = ({
   const { t } = useTranslation()
   const [open, setOpen] = useState(false)
   const [searchText, setSearchText] = useState('')
-  const { tags: options, tagsMap } = useTags()
+  const { tags: options, getTagLabel } = useTags()
   const filteredOptions = options.filter(option => option.name.toLowerCase().includes(searchText.toLowerCase()))
   const handleCheck = (id: string) => {
     if (value.includes(id))
@@ -59,7 +59,7 @@ const TagsFilter = ({
           !selectedTagsLength && t('pluginTags.allTags')
         }
         {
-          !!selectedTagsLength && value.map(val => tagsMap[val].label).slice(0, 2).join(',')
+          !!selectedTagsLength && value.map(val => getTagLabel(val)).slice(0, 2).join(',')
         }
         {
           selectedTagsLength > 2 && (
@@ -47,14 +47,14 @@ const useConfig = (id: string, payload: LoopNodeType) => {
   })
 
   const changeErrorResponseMode = useCallback((item: { value: unknown }) => {
-    const newInputs = produce(inputs, (draft) => {
+    const newInputs = produce(inputsRef.current, (draft) => {
       draft.error_handle_mode = item.value as ErrorHandleMode
     })
-    setInputs(newInputs)
-  }, [inputs, setInputs])
+    handleInputsChange(newInputs)
+  }, [inputs, handleInputsChange])
 
   const handleAddCondition = useCallback<HandleAddCondition>((valueSelector, varItem) => {
-    const newInputs = produce(inputs, (draft) => {
+    const newInputs = produce(inputsRef.current, (draft) => {
       if (!draft.break_conditions)
         draft.break_conditions = []
 
@@ -66,34 +66,34 @@ const useConfig = (id: string, payload: LoopNodeType) => {
         value: varItem.type === VarType.boolean ? 'false' : '',
       })
     })
-    setInputs(newInputs)
-  }, [getIsVarFileAttribute, inputs, setInputs])
+    handleInputsChange(newInputs)
+  }, [getIsVarFileAttribute, handleInputsChange])
 
   const handleRemoveCondition = useCallback<HandleRemoveCondition>((conditionId) => {
-    const newInputs = produce(inputs, (draft) => {
+    const newInputs = produce(inputsRef.current, (draft) => {
       draft.break_conditions = draft.break_conditions?.filter(item => item.id !== conditionId)
     })
-    setInputs(newInputs)
-  }, [inputs, setInputs])
+    handleInputsChange(newInputs)
+  }, [handleInputsChange])
 
   const handleUpdateCondition = useCallback<HandleUpdateCondition>((conditionId, newCondition) => {
-    const newInputs = produce(inputs, (draft) => {
+    const newInputs = produce(inputsRef.current, (draft) => {
       const targetCondition = draft.break_conditions?.find(item => item.id === conditionId)
       if (targetCondition)
         Object.assign(targetCondition, newCondition)
     })
-    setInputs(newInputs)
-  }, [inputs, setInputs])
+    handleInputsChange(newInputs)
+  }, [handleInputsChange])
 
   const handleToggleConditionLogicalOperator = useCallback<HandleToggleConditionLogicalOperator>(() => {
-    const newInputs = produce(inputs, (draft) => {
+    const newInputs = produce(inputsRef.current, (draft) => {
       draft.logical_operator = draft.logical_operator === LogicalOperator.and ? LogicalOperator.or : LogicalOperator.and
     })
-    setInputs(newInputs)
-  }, [inputs, setInputs])
+    handleInputsChange(newInputs)
+  }, [handleInputsChange])
 
   const handleAddSubVariableCondition = useCallback<HandleAddSubVariableCondition>((conditionId: string, key?: string) => {
-    const newInputs = produce(inputs, (draft) => {
+    const newInputs = produce(inputsRef.current, (draft) => {
       const condition = draft.break_conditions?.find(item => item.id === conditionId)
       if (!condition)
         return
@@ -119,11 +119,11 @@ const useConfig = (id: string, payload: LoopNodeType) => {
        })
      }
    })
-    setInputs(newInputs)
-  }, [inputs, setInputs])
+    handleInputsChange(newInputs)
+  }, [handleInputsChange])
 
   const handleRemoveSubVariableCondition = useCallback((conditionId: string, subConditionId: string) => {
-    const newInputs = produce(inputs, (draft) => {
+    const newInputs = produce(inputsRef.current, (draft) => {
      const condition = draft.break_conditions?.find(item => item.id === conditionId)
      if (!condition)
        return
@@ -133,11 +133,11 @@ const useConfig = (id: string, payload: LoopNodeType) => {
      if (subVarCondition)
        subVarCondition.conditions = subVarCondition.conditions.filter(item => item.id !== subConditionId)
    })
-    setInputs(newInputs)
-  }, [inputs, setInputs])
+    handleInputsChange(newInputs)
+  }, [handleInputsChange])
 
   const handleUpdateSubVariableCondition = useCallback<HandleUpdateSubVariableCondition>((conditionId, subConditionId, newSubCondition) => {
-    const newInputs = produce(inputs, (draft) => {
+    const newInputs = produce(inputsRef.current, (draft) => {
      const targetCondition = draft.break_conditions?.find(item => item.id === conditionId)
      if (targetCondition && targetCondition.sub_variable_condition) {
        const targetSubCondition = targetCondition.sub_variable_condition.conditions.find(item => item.id === subConditionId)
@@ -145,24 +145,24 @@ const useConfig = (id: string, payload: LoopNodeType) => {
          Object.assign(targetSubCondition, newSubCondition)
      }
    })
-    setInputs(newInputs)
-  }, [inputs, setInputs])
+    handleInputsChange(newInputs)
+  }, [handleInputsChange])
 
   const handleToggleSubVariableConditionLogicalOperator = useCallback<HandleToggleSubVariableConditionLogicalOperator>((conditionId) => {
-    const newInputs = produce(inputs, (draft) => {
+    const newInputs = produce(inputsRef.current, (draft) => {
      const targetCondition = draft.break_conditions?.find(item => item.id === conditionId)
      if (targetCondition && targetCondition.sub_variable_condition)
        targetCondition.sub_variable_condition.logical_operator = targetCondition.sub_variable_condition.logical_operator === LogicalOperator.and ? LogicalOperator.or : LogicalOperator.and
    })
-    setInputs(newInputs)
-  }, [inputs, setInputs])
+    handleInputsChange(newInputs)
+  }, [handleInputsChange])
 
   const handleUpdateLoopCount = useCallback((value: number) => {
-    const newInputs = produce(inputs, (draft) => {
+    const newInputs = produce(inputsRef.current, (draft) => {
      draft.loop_count = value
    })
-    setInputs(newInputs)
-  }, [inputs, setInputs])
+    handleInputsChange(newInputs)
+  }, [handleInputsChange])
 
   const handleAddLoopVariable = useCallback(() => {
     const newInputs = produce(inputsRef.current, (draft) => {
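The pattern behind every one of these callbacks: reading from `inputsRef.current` instead of the captured `inputs` means each `useCallback` always sees the latest state, even when several updates land before React re-renders, and funneling updates through `handleInputsChange` keeps the ref and the store in sync. A minimal sketch of the idea with illustrative names (not the actual Dify hook):

```ts
// Sketch of the ref-backed update pattern (illustrative names).
// A callback created during an earlier render would otherwise close over
// a stale `inputs` value; the ref always holds the most recent one.
import { useCallback, useRef, useState } from 'react'
import { produce } from 'immer'

type Inputs = { loop_count: number }

function useLoopConfig(initial: Inputs) {
  const [inputs, setInputs] = useState(initial)
  const inputsRef = useRef(inputs)

  // Single funnel: update the ref synchronously, then the React state.
  const handleInputsChange = useCallback((next: Inputs) => {
    inputsRef.current = next
    setInputs(next)
  }, [])

  const handleUpdateLoopCount = useCallback((value: number) => {
    const next = produce(inputsRef.current, (draft) => {
      draft.loop_count = value
    })
    handleInputsChange(next)
  }, [handleInputsChange])

  return { inputs, handleUpdateLoopCount }
}
```

A side effect of this shape: the callbacks no longer list `inputs` in their dependency arrays, so their identities stay stable across renders.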
@@ -2,6 +2,7 @@ import { InputVarType } from '@/app/components/workflow/types'
 import { AgentStrategy } from '@/types/app'
 import { PromptRole } from '@/models/debug'
 import { DatasetAttr } from '@/types/feature'
+import pkg from '../package.json'
 
 const getBooleanConfig = (envVar: string | undefined, dataAttrKey: DatasetAttr, defaultValue: boolean = true) => {
   if (envVar !== undefined && envVar !== '')
@@ -294,3 +295,5 @@ export const ZENDESK_FIELD_IDS = {
   WORKSPACE_ID: getStringConfig(process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID, DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID, ''),
   PLAN: getStringConfig(process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN, DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN, ''),
 }
+
+export const APP_VERSION = pkg.version
@@ -2,7 +2,7 @@ import type { AfterResponseHook, BeforeErrorHook, BeforeRequestHook, Hooks } fro
 import ky from 'ky'
 import type { IOtherOptions } from './base'
 import Toast from '@/app/components/base/toast'
-import { API_PREFIX, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config'
+import { API_PREFIX, APP_VERSION, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config'
 import { getInitialTokenV2, isTokenV1 } from '@/app/components/share/utils'
 import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/chat/utils'
 
@@ -151,6 +151,10 @@ async function base<T>(url: string, options: FetchOptionType = {}, otherOptions:
   if (deleteContentType)
     (headers as any).delete('Content-Type')
 
+  // ! For Marketplace API, help to filter tags added in new version
+  if (isMarketplaceAPI)
+    (headers as any).set('X-Dify-Version', APP_VERSION)
+
   const client = baseClient.extend({
     hooks: {
       ...baseHooks,
@@ -3,8 +3,8 @@ html[data-theme="dark"] {
     rgba(34, 34, 37, 0.9) 0%,
     rgba(29, 29, 32, 0.9) 90.48%);
   --color-chat-bubble-bg: linear-gradient(180deg,
-    rgba(200, 206, 218, 0.08) 0%,
-    rgba(200, 206, 218, 0.02) 100%);
+    rgb(42, 43, 48) 0%,
+    rgb(37, 38, 42) 100%);
   --color-chat-input-mask: linear-gradient(180deg,
     rgba(24, 24, 27, 0.04) 0%,
     rgba(24, 24, 27, 0.60) 100%);