Compare commits

..

56 Commits

Author SHA1 Message Date
Joe
c12596af48 fix: trace_app_config_app_id_idx 2024-07-20 01:14:46 +08:00
Joe
27e08a8e2e Fix/extra table tracing app config (#6487) 2024-07-20 00:53:31 +08:00
Matri
49ef9ef225 feat(tool): getimg.ai integration (#6260) 2024-07-19 20:32:42 +08:00
Even
c013086e64 fix: next suggest question logic problem (#6451)
Co-authored-by: evenyan <yikun.yan@ubtrobot.com>
2024-07-19 20:26:11 +08:00
crazywoola
48f872a68c fix: build error (#6480) 2024-07-19 18:37:42 +08:00
sino
4f9f175f25 fix: correct gpt-4o-mini max token (#6472)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-07-19 18:24:58 +08:00
moqimoqidea
47e5dc218a Update CONTRIBUTING_CN "安装常见问题解答" link. (#6470) 2024-07-19 17:06:32 +08:00
moqimoqidea
90372932fe Update CONTRIBUTING "installation FAQ" link. (#6471) 2024-07-19 17:05:30 +08:00
moqimoqidea
0bb2b285da Update CONTRIBUTING_JA "installation FAQ" link. (#6469) 2024-07-19 17:05:20 +08:00
Joel
3da854fe40 chore: some components upgrage to new ui (#6468) 2024-07-19 16:39:49 +08:00
Jyong
57729823a0 fix wrong method using (#6459) 2024-07-19 13:48:13 +08:00
sino
9e168f9d1c feat: support gpt-4o-mini for openrouter provider (#6447) 2024-07-19 13:09:41 +08:00
Weaxs
ea45496a74 update ernie models (#6454) 2024-07-19 13:08:39 +08:00
Sangmin Ahn
a5fcd91ba5 chore: make text generation timeout duration configurable (#6450) 2024-07-19 12:54:15 +08:00
Waffle
2ba05b041f refactor(myscale):Set the default value of the myscale vector db in DifyConfig. (#6441) 2024-07-19 10:57:45 +08:00
Richards Tu
8e49146a35 [EMERGENCY] Fix Anthropic header issue (#6445) 2024-07-19 07:38:15 +08:00
takatost
dad3fd2dc1 feat: add gpt-4o-mini (#6442) 2024-07-19 01:53:43 +08:00
yoyocircle
284ef52bba feat: passing the inputs values using difyChatbotConfig (#6376) 2024-07-18 21:54:16 +08:00
Jyong
e493ce9981 update clean embedding cache logic (#6434) 2024-07-18 20:25:28 +08:00
Weishan-0
7b45a5d452 fix: Unable to display images generated by Dall-E 3 (#6155) 2024-07-18 19:37:04 +08:00
ybalbert001
4a026fa352 Enhancement: add model provider - Amazon Sagemaker (#6255)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-07-18 19:32:31 +08:00
leoterry
dc847ba145 Fix the vector retrieval sorting issue (#6431)
Co-authored-by: weifj <“weifj@tuyuansu.com.cn”>
2024-07-18 19:25:41 +08:00
-LAN-
c0ec40e483 fix(api/core/tools/provider/builtin/spider/tools/scraper_crawler.yaml): Fix wrong placeholder config in scraper crawler tool. (#6432) 2024-07-18 19:23:18 +08:00
Carson
929c22a4e8 fix: tools edit modal schema edit issue (#6396) 2024-07-18 19:02:23 +08:00
themanforfree
ba181197c2 feat: api_key support for xinference (#6417)
Signed-off-by: themanforfree <themanforfree@gmail.com>
2024-07-18 18:58:46 +08:00
Songyawn
218930c897 fix tool icon get failed (#6375)
Co-authored-by: songyawen <songyawen@zkme.xyz>
2024-07-18 18:55:48 +08:00
Poorandy
c8f5dfcf17 refactor(rag): switch to dify_config. (#6410)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-07-18 18:40:36 +08:00
forrestsocool
27c8deb4ec feat: add custom tool timeout config to docker-compose.yaml and .env (#6419)
Signed-off-by: forrestsocool <sensensudo@gmail.com>
2024-07-18 18:40:17 +08:00
Joel
4ae4895ebe feat: add frontend unit test framework (#6426) 2024-07-18 17:35:10 +08:00
非法操作
afe95fa780 feat: support get workflow task execution status (#6411) 2024-07-18 15:06:14 +08:00
Kuizuo
166a40c66e fix: improve separation element in prompt log and TTS buttons in the operation (#6413) 2024-07-18 14:44:34 +08:00
William Espegren
588615b20e feat: Spider web scraper & crawler tool (#5725) 2024-07-18 14:29:33 +08:00
listeng
d5dca46854 feat: add a Tianditu tool (#6320)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-07-18 13:04:03 +08:00
Xiao Ley
23e5eeec00 feat: added custom secure_ascii to the json_process tool (#6401) 2024-07-18 08:43:14 +08:00
Harry Wang
287b42997d fix inconsistent label (#6404) 2024-07-18 08:37:16 +08:00
Masashi Tomooka
5236cb1888 fix: kill signal is not passed to the main process (#6159) 2024-07-18 07:50:54 +08:00
forrestlinfeng
3b5b548af3 Add Stepfun LLM Support (#6346) 2024-07-18 07:47:18 +08:00
Richards Tu
4782fb50c4 Support new Claude-3.5 Sonnet max token limit (#6335) 2024-07-18 07:47:06 +08:00
Jyong
f55876bcc5 fix web import url is too long (#6402) 2024-07-18 01:14:36 +08:00
Poorandy
8a80af39c9 refactor(models&tools): switch to dify_config in models and tools. (#6394)
Co-authored-by: Poorandy <andymonicamua1@gmail.com>
2024-07-17 22:26:18 +08:00
crazywoola
35f4a264d6 fix: default duration (#6393) 2024-07-17 21:19:04 +08:00
非法操作
6c798cbdaf fix: tool authorization setting panel not validate required fields (#6387) 2024-07-17 21:10:28 +08:00
Charlie.Wei
279f1c986f embed.js add esc exit and fix avoid infinite nesting (#6360)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2024-07-17 20:52:44 +08:00
Jyong
443e96777b update empty document caused delete exist collection (#6392) 2024-07-17 20:38:32 +08:00
faye1225
65bc4e0fc0 Fix issues related to search apps, notification duration, and loading icon on the explore page (#6374) 2024-07-17 20:24:31 +08:00
chenxu9741
a6dbd26f75 Add the API documentation for streaming TTS (Text-to-Speech) (#6382) 2024-07-17 19:44:16 +08:00
xielong
f3f052ba36 fix: rename model from ernie-4.0-8k-Latest to ernie-4.0-8k-latest (#6383) 2024-07-17 19:07:47 +08:00
Jyong
1bc90b992b Feat/optimize clean dataset logic (#6384) 2024-07-17 17:36:11 +08:00
-LAN-
fc37887a21 refactor(api/core/workflow/nodes/http_request): Remove mask_authorization_header because its alwary true. (#6379) 2024-07-17 16:52:14 +08:00
zxhlyh
984658f5e9 fix: workflow sync before export (#6380) 2024-07-17 16:51:48 +08:00
Lion
4ed1476531 fix: incorrect config key name (#6371)
Co-authored-by: LionYuYu <lyu@theknotww.com>
2024-07-17 15:52:51 +08:00
chenxu9741
ca69e1a2f5 Add multilingual support for TTS (Text-to-Speech) functionality. (#6369) 2024-07-17 14:41:29 +08:00
FamousMai
20f73cb756 fix: default model set wrong(#6327) (#6332)
Co-authored-by: maiyouming <maiyouming@yafex.cn>
2024-07-17 14:14:12 +08:00
Weaxs
4e2fba404d WebscraperTool bypass cloudflare site by cloudscraper (#6337) 2024-07-17 14:13:57 +08:00
Bowen Liang
7943f7f697 chore: fix legacy API usages of Query.get() by Session.get() in SqlAlchemy 2 (#6340) 2024-07-17 13:54:35 +08:00
Jyong
7c397f5722 update celery beat scheduler time to env (#6352) 2024-07-17 02:31:30 +08:00
189 changed files with 6595 additions and 596 deletions


@@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o
Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and steps to troubleshoot.
Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
### 5. Visit dify in your browser


@@ -77,7 +77,7 @@ Dify 依赖以下工具和库:
Dify 由后端和前端组成。通过 `cd api/` 导航到后端目录,然后按照 [后端 README](api/README.md) 进行安装。在另一个终端中,通过 `cd web/` 导航到前端目录,然后按照 [前端 README](web/README.md) 进行安装。
查看 [安装常见问题解答](https://docs.dify.ai/getting-started/faq/install-faq) 以获取常见问题列表和故障排除步骤。
查看 [安装常见问题解答](https://docs.dify.ai/v/zh-hans/learn-more/faq/install-faq) 以获取常见问题列表和故障排除步骤。
### 5. 在浏览器中访问 Dify


@@ -82,7 +82,7 @@ Dify はバックエンドとフロントエンドから構成されています
まず`cd api/`でバックエンドのディレクトリに移動し、[Backend README](api/README.md)に従ってインストールします。
次に別のターミナルで、`cd web/`でフロントエンドのディレクトリに移動し、[Frontend README](web/README.md)に従ってインストールしてください。
よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) を確認してください。
よくある問題とトラブルシューティングの手順については、[installation FAQ](https://docs.dify.ai/v/japanese/learn-more/faq/install-faq) を確認してください。
### 5. ブラウザで dify にアクセスする


@@ -79,7 +79,7 @@ class HostedAzureOpenAiConfig(BaseSettings):
default=False,
)
HOSTED_OPENAI_API_KEY: Optional[str] = Field(
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
description='',
default=None,
)
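The rename matters because pydantic-settings resolves each field from the environment variable of the same name, so a field named HOSTED_OPENAI_API_KEY could never pick up the Azure key. A minimal standalone sketch of that lookup (not Dify code; assumes pydantic v2 with pydantic-settings installed):

from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings

class HostedAzureOpenAiConfig(BaseSettings):
    # The field name is the env var key: this reads HOSTED_AZURE_OPENAI_API_KEY.
    HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(description='', default=None)

config = HostedAzureOpenAiConfig()  # values populated from os.environ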


@@ -1,4 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
@@ -8,32 +7,32 @@ class MyScaleConfig(BaseModel):
MyScale configs
"""
MYSCALE_HOST: Optional[str] = Field(
MYSCALE_HOST: str = Field(
description='MyScale host',
default=None,
default='localhost',
)
MYSCALE_PORT: Optional[PositiveInt] = Field(
MYSCALE_PORT: PositiveInt = Field(
description='MyScale port',
default=8123,
)
MYSCALE_USER: Optional[str] = Field(
MYSCALE_USER: str = Field(
description='MyScale user',
default=None,
default='default',
)
MYSCALE_PASSWORD: Optional[str] = Field(
MYSCALE_PASSWORD: str = Field(
description='MyScale password',
default=None,
default='',
)
MYSCALE_DATABASE: Optional[str] = Field(
MYSCALE_DATABASE: str = Field(
description='MyScale database name',
default=None,
default='default',
)
MYSCALE_FTS_PARAMS: Optional[str] = Field(
MYSCALE_FTS_PARAMS: str = Field(
description='MyScale fts index parameters',
default=None,
default='',
)
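With the Optional types dropped in favor of concrete defaults, the config validates even when no MYSCALE_* values are supplied. A quick sketch of the resulting defaults (assuming the MyScaleConfig class above):

config = MyScaleConfig()  # no overrides supplied
assert config.MYSCALE_HOST == 'localhost'
assert config.MYSCALE_PORT == 8123
assert config.MYSCALE_USER == 'default'
assert config.MYSCALE_DATABASE == 'default'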


@@ -75,7 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
)
if last_id is not None:
last_segment = DocumentSegment.query.get(str(last_id))
last_segment = db.session.get(DocumentSegment, str(last_id))
if last_segment:
query = query.filter(
DocumentSegment.position > last_segment.position)


@@ -117,7 +117,7 @@ class MemberUpdateRoleApi(Resource):
if not TenantAccountRole.is_valid_role(new_role):
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
member = Account.query.get(str(member_id))
member = db.session.get(Account, str(member_id))
if not member:
abort(404)
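For context, Query.get() is a legacy API under SQLAlchemy 2.x; Session.get() is the supported replacement and, like the old call, returns None when no row matches the primary key. A minimal before/after sketch using the names from the hunks above:

# Legacy pattern (emits LegacyAPIWarning on SQLAlchemy 2.x):
member = Account.query.get(str(member_id))

# SQLAlchemy 2.x pattern used by this change; returns None if the id is absent:
member = db.session.get(Account, str(member_id))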


@@ -1,6 +1,6 @@
import logging
from flask_restful import Resource, reqparse
from flask_restful import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
@@ -21,14 +21,43 @@ from core.errors.error import (
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from services.app_generate_service import AppGenerateService
logger = logging.getLogger(__name__)
class WorkflowRunApi(Resource):
workflow_run_fields = {
'id': fields.String,
'workflow_id': fields.String,
'status': fields.String,
'inputs': fields.Raw,
'outputs': fields.Raw,
'error': fields.String,
'total_steps': fields.Integer,
'total_tokens': fields.Integer,
'created_at': fields.DateTime,
'finished_at': fields.DateTime,
'elapsed_time': fields.Float,
}
@validate_app_token
@marshal_with(workflow_run_fields)
def get(self, app_model: App, workflow_id: str):
"""
Get a workflow task running detail
"""
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
return workflow_run
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
"""
@@ -88,5 +117,5 @@ class WorkflowTaskStopApi(Resource):
}
api.add_resource(WorkflowRunApi, '/workflows/run')
api.add_resource(WorkflowRunApi, '/workflows/run/<string:workflow_id>', '/workflows/run')
api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop')
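The new GET route makes a workflow run's status retrievable by id over the service API. A hedged usage sketch (placeholder base URL, app token, and run id):

import requests

resp = requests.get(
    "https://api.dify.example/v1/workflows/run/WORKFLOW_RUN_ID",  # placeholders
    headers={"Authorization": "Bearer app-YOUR-TOKEN"},
)
print(resp.json())  # id, workflow_id, status, inputs, outputs, error, ...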


@@ -64,6 +64,7 @@ User Input:
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
"The output must be an array in JSON format following the specified schema:\n"
"[\"question1\",\"question2\",\"question3\"]\n"
)


@@ -103,7 +103,7 @@ class TokenBufferMemory:
if curr_message_tokens > max_token_limit:
pruned_memory = []
while curr_message_tokens > max_token_limit and prompt_messages:
while curr_message_tokens > max_token_limit and len(prompt_messages)>1:
pruned_memory.append(prompt_messages.pop(0))
curr_message_tokens = self.model_instance.get_llm_num_tokens(
prompt_messages
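The revised condition keeps at least one message in the window: under the old `and prompt_messages` guard, a single message that alone exceeded the budget would also be popped, leaving an empty prompt. A toy illustration (count_tokens is a hypothetical stand-in for the model's token counting):

def count_tokens(messages):  # hypothetical stand-in
    return sum(len(m.split()) for m in messages)

prompt_messages = ["one very long message well over the token budget for this example"]
max_token_limit = 5

while count_tokens(prompt_messages) > max_token_limit and len(prompt_messages) > 1:
    prompt_messages.pop(0)

assert prompt_messages  # the last message is never pruned away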


@@ -27,9 +27,9 @@ parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
default: 8192
min: 1
max: 4096
max: 8192
- name: response_format
use_template: response_format
pricing:


@@ -113,6 +113,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
if system:
extra_model_kwargs['system'] = system
# Add the new header for claude-3-5-sonnet-20240620 model
extra_headers = {}
if model == "claude-3-5-sonnet-20240620":
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
if tools:
extra_model_kwargs['tools'] = [
self._transform_tool_prompt(tool) for tool in tools
@@ -121,6 +126,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
model=model,
messages=prompt_message_dicts,
stream=stream,
extra_headers=extra_headers,
**model_parameters,
**extra_model_kwargs
)
@@ -130,6 +136,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
model=model,
messages=prompt_message_dicts,
stream=stream,
extra_headers=extra_headers,
**model_parameters,
**extra_model_kwargs
)
@@ -138,7 +145,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
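For reference, this beta header is what unlocked the 8192-token output cap on claude-3-5-sonnet-20240620 at the time. A minimal sketch of the equivalent direct SDK call (assuming the anthropic Python SDK with ANTHROPIC_API_KEY set in the environment):

import anthropic

client = anthropic.Anthropic()
message = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=8192,  # above the default 4096 cap
    extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"},
    messages=[{"role": "user", "content": "Hello"}],
)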


@@ -1,6 +1,8 @@
- gpt-4
- gpt-4o
- gpt-4o-2024-05-13
- gpt-4o-mini
- gpt-4o-mini-2024-07-18
- gpt-4-turbo
- gpt-4-turbo-2024-04-09
- gpt-4-turbo-preview


@@ -0,0 +1,44 @@
model: gpt-4o-mini-2024-07-18
label:
zh_Hans: gpt-4o-mini-2024-07-18
en_US: gpt-4o-mini-2024-07-18
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.15'
output: '0.60'
unit: '0.000001'
currency: USD
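For reference, the pricing block is per-token: the listed price times unit gives the cost per token, which lines up with the advertised USD 0.15 per 1M input tokens for gpt-4o-mini. A quick check:

price_per_input_token = 0.15 * 0.000001  # input price * unit
print(round(price_per_input_token * 1_000_000, 2))  # 0.15 USD per 1M input tokens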


@@ -0,0 +1,44 @@
model: gpt-4o-mini
label:
zh_Hans: gpt-4o-mini
en_US: gpt-4o-mini
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.15'
output: '0.60'
unit: '0.000001'
currency: USD


@@ -1,4 +1,5 @@
- openai/gpt-4o
- openai/gpt-4o-mini
- openai/gpt-4
- openai/gpt-4-32k
- openai/gpt-3.5-turbo


@@ -0,0 +1,43 @@
model: openai/gpt-4o-mini
label:
en_US: gpt-4o-mini
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 16384
- name: response_format
label:
zh_Hans: 回复格式
en_US: response_format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: "0.15"
output: "0.60"
unit: "0.000001"
currency: USD

(Two binary image files added, not shown: 9.2 KiB and 9.5 KiB.)


@@ -0,0 +1,238 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union
import boto3
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
class SageMakerLargeLanguageModel(LargeLanguageModel):
"""
Model class for Cohere large language model.
"""
sagemaker_client: Any = None
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# get model mode
model_mode = self.get_model_mode(model, credentials)
if not self.sagemaker_client:
access_key = credentials.get('access_key')
secret_key = credentials.get('secret_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=sagemaker_endpoint,
Body=json.dumps(
{
"inputs": prompt_messages[0].content,
"parameters": { "stop" : stop},
"history" : []
}
),
ContentType="application/json",
)
assistant_text = response_model['Body'].read().decode('utf8')
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
usage = self._calc_response_usage(model, credentials, 0, 0)
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage
)
return response
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
# get model mode
model_mode = self.get_model_mode(model)
try:
return 0
except Exception as e:
raise self._transform_invoke_error(e)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
# get model mode
model_mode = self.get_model_mode(model)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature',
type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
),
),
ParameterRule(
name='top_p',
type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P',
en_US='Top P'
)
),
ParameterRule(
name='max_tokens',
type=ParameterType.INT,
use_template='max_tokens',
min=1,
max=credentials.get('context_length', 2048),
default=512,
label=I18nObject(
zh_Hans='最大生成长度',
en_US='Max Tokens'
)
)
]
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type == LLMMode.CHAT:
print(f"completion_type : {LLMMode.CHAT.value}")
if completion_type == LLMMode.COMPLETION:
print(f"completion_type : {LLMMode.COMPLETION.value}")
features = []
support_function_call = credentials.get('support_function_call', False)
if support_function_call:
features.append(ModelFeature.TOOL_CALL)
support_vision = credentials.get('support_vision', False)
if support_vision:
features.append(ModelFeature.VISION)
context_length = credentials.get('context_length', 2048)
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=features,
model_properties={
ModelPropertyKey.MODE: completion_type,
ModelPropertyKey.CONTEXT_SIZE: context_length
},
parameter_rules=rules
)
return entity
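The invoke path above reduces to one sagemaker-runtime call. A stripped-down sketch of the same boto3 request (hypothetical endpoint name and region):

import json

import boto3

client = boto3.client("sagemaker-runtime", region_name="us-east-1")
response = client.invoke_endpoint(
    EndpointName="my-llm-endpoint",  # hypothetical endpoint
    Body=json.dumps({"inputs": "Hello", "parameters": {"stop": []}, "history": []}),
    ContentType="application/json",
)
print(response["Body"].read().decode("utf8"))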


@@ -0,0 +1,190 @@
import json
import logging
from typing import Any, Optional
import boto3
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
logger = logging.getLogger(__name__)
class SageMakerRerankModel(RerankModel):
"""
Model class for Cohere rerank model.
"""
sagemaker_client: Any = None
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str):
inputs = [query_input]*len(docs)
response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=rerank_endpoint,
Body=json.dumps(
{
"inputs": inputs,
"docs": docs
}
),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
scores = json_obj['scores']
return scores if isinstance(scores, list) else [scores]
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) \
-> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
line = 0
try:
if len(docs) == 0:
return RerankResult(
model=model,
docs=docs
)
line = 1
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
line = 2
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
candidate_docs = []
scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)
for idx in range(len(scores)):
candidate_docs.append({"content" : docs[idx], "score": scores[idx]})
sorted(candidate_docs, key=lambda x: x['score'], reverse=True)
line = 3
rerank_documents = []
for idx, result in enumerate(candidate_docs):
rerank_document = RerankDocument(
index=idx,
text=result.get('content'),
score=result.get('score', -100.0)
)
if score_threshold is not None:
if rerank_document.score >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document)
return RerankResult(
model=model,
docs=rerank_documents
)
except Exception as e:
logger.exception(f'Exception {e}, line : {line}')
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={ },
parameter_rules=[]
)
return entity


@@ -0,0 +1,17 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class SageMakerProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
pass


@@ -0,0 +1,125 @@
provider: sagemaker
label:
zh_Hans: Sagemaker
en_US: Sagemaker
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
description:
en_US: Customized model on Sagemaker
zh_Hans: Sagemaker上的私有化部署的模型
background: "#ECE9E3"
help:
title:
en_US: How to deploy customized model on Sagemaker
zh_Hans: 如何在Sagemaker上的私有化部署的模型
url:
en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint
zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9
supported_model_types:
- llm
- text-embedding
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: sagemaker_endpoint
label:
en_US: sagemaker endpoint
type: text-input
required: true
placeholder:
zh_Hans: 请输出你的Sagemaker推理端点
en_US: Enter your Sagemaker Inference endpoint
- variable: aws_access_key_id
required: false
label:
en_US: Access Key (If not provided, credentials are obtained from the running environment.)
zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。)
type: secret-input
placeholder:
en_US: Enter your Access Key
zh_Hans: 在此输入您的 Access Key
- variable: aws_secret_access_key
required: false
label:
en_US: Secret Access Key
zh_Hans: Secret Access Key
type: secret-input
placeholder:
en_US: Enter your Secret Access Key
zh_Hans: 在此输入您的 Secret Access Key
- variable: aws_region
required: false
label:
en_US: AWS Region
zh_Hans: AWS 地区
type: select
default: us-east-1
options:
- value: us-east-1
label:
en_US: US East (N. Virginia)
zh_Hans: 美国东部 (弗吉尼亚北部)
- value: us-west-2
label:
en_US: US West (Oregon)
zh_Hans: 美国西部 (俄勒冈州)
- value: ap-southeast-1
label:
en_US: Asia Pacific (Singapore)
zh_Hans: 亚太地区 (新加坡)
- value: ap-northeast-1
label:
en_US: Asia Pacific (Tokyo)
zh_Hans: 亚太地区 (东京)
- value: eu-central-1
label:
en_US: Europe (Frankfurt)
zh_Hans: 欧洲 (法兰克福)
- value: us-gov-west-1
label:
en_US: AWS GovCloud (US-West)
zh_Hans: AWS GovCloud (US-West)
- value: ap-southeast-2
label:
en_US: Asia Pacific (Sydney)
zh_Hans: 亚太地区 (悉尼)
- value: cn-north-1
label:
en_US: AWS Beijing (cn-north-1)
zh_Hans: 中国北京 (cn-north-1)
- value: cn-northwest-1
label:
en_US: AWS Ningxia (cn-northwest-1)
zh_Hans: 中国宁夏 (cn-northwest-1)


@@ -0,0 +1,214 @@
import itertools
import json
import logging
import time
from typing import Any, Optional
import boto3
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
BATCH_SIZE = 20
CONTEXT_SIZE=8192
logger = logging.getLogger(__name__)
def batch_generator(generator, batch_size):
while True:
batch = list(itertools.islice(generator, batch_size))
if not batch:
break
yield batch
class SageMakerEmbeddingModel(TextEmbeddingModel):
"""
Model class for Cohere text embedding model.
"""
sagemaker_client: Any = None
def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]):
response_model = sm_client.invoke_endpoint(
EndpointName=endpoint_name,
Body=json.dumps(
{
"inputs": content_list,
"parameters": {},
"is_query" : False,
"instruction" : ''
}
),
ContentType="application/json",
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
embeddings = json_obj['embeddings']
return embeddings
def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
# get model properties
try:
line = 1
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
line = 2
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
line = 3
truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ]
batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
all_embeddings = []
line = 4
for batch in batches:
embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
all_embeddings.extend(embeddings)
line = 5
# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=0 # It's not SAAS API, usage is meaningless
)
line = 6
return TextEmbeddingResult(
embeddings=all_embeddings,
usage=usage,
model=model
)
except Exception as e:
logger.exception(f'Exception {e}, line : {line}')
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return 0
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
print("validate_credentials ok....")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)
return usage
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
},
parameter_rules=[]
)
return entity
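batch_generator above slices any iterator into fixed-size lists, which is how the texts are sent to the endpoint BATCH_SIZE at a time. A quick usage sketch (function repeated to make the example runnable):

import itertools

def batch_generator(generator, batch_size):
    while True:
        batch = list(itertools.islice(generator, batch_size))
        if not batch:
            break
        yield batch

batches = list(batch_generator(iter(range(45)), batch_size=20))
assert [len(b) for b in batches] == [20, 20, 5]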

(Two binary image files added, not shown: 9.0 KiB and 1.9 KiB.)


@@ -0,0 +1,6 @@
- step-1-8k
- step-1-32k
- step-1-128k
- step-1-256k
- step-1v-8k
- step-1v-32k


@@ -0,0 +1,328 @@
import json
from collections.abc import Generator
from typing import Optional, Union, cast
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
self._add_function_call(model, credentials)
user = user[:32] if user else None
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get('function_calling_type') == 'tool_call'
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name='temperature',
use_template='temperature',
label=I18nObject(en_US='Temperature', zh_Hans='温度'),
type=ParameterType.FLOAT,
),
ParameterRule(
name='max_tokens',
use_template='max_tokens',
default=512,
min=1,
max=int(credentials.get('max_tokens', 1024)),
label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
type=ParameterType.INT,
),
ParameterRule(
name='top_p',
use_template='top_p',
label=I18nObject(en_US='Top P', zh_Hans='Top P'),
type=ParameterType.FLOAT,
),
]
)
def _add_custom_parameters(self, credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.stepfun.com/v1'
def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials)
if model_schema and {
ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
}.intersection(model_schema.features or []):
credentials['function_calling_type'] = 'tool_call'
def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict:
"""
Convert PromptMessage to dict for OpenAI API format
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = []
for function_call in message.tool_calls:
message_dict["tool_calls"].append({
"id": function_call.id,
"type": function_call.type,
"function": {
"name": function_call.function.name,
"arguments": function_call.function.arguments
}
})
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call["id"] if response_tool_call.get("id") else "",
type=response_tool_call["type"] if response_tool_call.get("type") else "",
function=function
)
tool_calls.append(tool_call)
return tool_calls
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: model credentials
:param response: streamed response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ''
chunk_index = 0
def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
-> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=message,
finish_reason=finish_reason,
usage=usage
)
)
tools_calls: list[AssistantPromptMessage.ToolCall] = []
finish_reason = "Unknown"
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id='',
type='',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
if chunk:
# ignore sse comments
if chunk.startswith(':'):
continue
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
break
if not chunk_json or len(chunk_json['choices']) == 0:
continue
choice = chunk_json['choices'][0]
finish_reason = chunk_json['choices'][0].get('finish_reason')
chunk_index += 1
if 'delta' in choice:
delta = choice['delta']
delta_content = delta.get('content')
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content,
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta_content
elif 'text' in choice:
choice_text = choice.get('text', '')
if choice_text == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
full_assistant_content += choice_text
else:
continue
# check payload indicator for completion
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
)
)
chunk_index += 1
if tools_calls:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(
tool_calls=tools_calls,
content=""
),
)
)
yield create_final_llm_result_chunk(
index=chunk_index,
message=AssistantPromptMessage(content=""),
finish_reason=finish_reason
)
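The stream handler above walks OpenAI-compatible SSE events. An illustrative payload it would parse once the "data: " prefix is stripped (hypothetical content):

import json

decoded_chunk = '{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}'
chunk_json = json.loads(decoded_chunk)
delta_content = chunk_json["choices"][0]["delta"].get("content")  # -> "Hello"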


@@ -0,0 +1,25 @@
model: step-1-128k
label:
zh_Hans: step-1-128k
en_US: step-1-128k
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 128000
pricing:
input: '0.04'
output: '0.20'
unit: '0.001'
currency: RMB


@@ -0,0 +1,25 @@
model: step-1-256k
label:
zh_Hans: step-1-256k
en_US: step-1-256k
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 256000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 256000
pricing:
input: '0.095'
output: '0.300'
unit: '0.001'
currency: RMB


@@ -0,0 +1,28 @@
model: step-1-32k
label:
zh_Hans: step-1-32k
en_US: step-1-32k
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
pricing:
input: '0.015'
output: '0.070'
unit: '0.001'
currency: RMB


@@ -0,0 +1,28 @@
model: step-1-8k
label:
zh_Hans: step-1-8k
en_US: step-1-8k
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 8000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8000
pricing:
input: '0.005'
output: '0.020'
unit: '0.001'
currency: RMB


@@ -0,0 +1,25 @@
model: step-1v-32k
label:
zh_Hans: step-1v-32k
en_US: step-1v-32k
model_type: llm
features:
- vision
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32000
pricing:
input: '0.015'
output: '0.070'
unit: '0.001'
currency: RMB


@@ -0,0 +1,25 @@
model: step-1v-8k
label:
zh_Hans: step-1v-8k
en_US: step-1v-8k
model_type: llm
features:
- vision
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.005'
output: '0.020'
unit: '0.001'
currency: RMB


@@ -0,0 +1,30 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class StepfunProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='step-1-8k',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex


@@ -0,0 +1,81 @@
provider: stepfun
label:
zh_Hans: 阶跃星辰
en_US: Stepfun
description:
en_US: Models provided by stepfun, such as step-1-8k, step-1-32k、step-1v-8k、step-1v-32k, step-1-128k and step-1-256k
zh_Hans: 阶跃星辰提供的模型,例如 step-1-8k、step-1-32k、step-1v-8k、step-1v-32k、step-1-128k 和 step-1-256k。
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
background: "#FFFFFF"
help:
title:
en_US: Get your API Key from stepfun
zh_Hans: 从 stepfun 获取 API Key
url:
en_US: https://platform.stepfun.com/interface-key
supported_model_types:
- llm
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '8192'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
default: '8192'
type: text-input
- variable: function_calling_type
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: no_call
label:
en_US: Not supported
zh_Hans: 不支持
- value: tool_call
label:
en_US: Tool Call
zh_Hans: Tool Call


@@ -35,3 +35,4 @@ parameter_rules:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model to perform external search.
required: false
deprecated: true


@@ -1,4 +1,4 @@
model: ernie-4.0-8k-Latest
model: ernie-4.0-8k-latest
label:
en_US: Ernie-4.0-8K-Latest
model_type: llm


@@ -0,0 +1,40 @@
model: ernie-4.0-turbo-8k
label:
en_US: Ernie-4.0-turbo-8K
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.1
max: 1.0
default: 0.8
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 2
max: 2048
- name: presence_penalty
use_template: presence_penalty
default: 1.0
min: 1.0
max: 2.0
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
use_template: response_format
- name: disable_search
label:
zh_Hans: 禁用搜索
en_US: Disable Search
type: boolean
help:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model to perform external search.
required: false


@@ -28,3 +28,4 @@ parameter_rules:
default: 1.0
min: 1.0
max: 2.0
deprecated: true


@@ -0,0 +1,30 @@
model: ernie-character-8k-0321
label:
en_US: ERNIE-Character-8K
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.1
max: 1.0
default: 0.95
- name: top_p
use_template: top_p
min: 0
max: 1.0
default: 0.7
- name: max_tokens
use_template: max_tokens
default: 1024
min: 2
max: 1024
- name: presence_penalty
use_template: presence_penalty
default: 1.0
min: 1.0
max: 2.0


@@ -28,3 +28,4 @@ parameter_rules:
default: 1.0
min: 1.0
max: 2.0
deprecated: true


@@ -28,3 +28,4 @@ parameter_rules:
default: 1.0
min: 1.0
max: 2.0
deprecated: true


@@ -97,6 +97,7 @@ class BaiduAccessToken:
baidu_access_tokens_lock.release()
return token
class ErnieMessage:
class Role(Enum):
USER = 'user'
@@ -137,7 +138,9 @@ class ErnieBotModel:
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-tutbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-tutbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
}
@@ -149,7 +152,8 @@ class ErnieBotModel:
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k'
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview'
]


@@ -453,9 +453,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
api_key = credentials.get('api_key') or "abc"
client = OpenAI(
base_url=f'{credentials["server_url"]}/v1',
api_key='abc',
api_key=api_key,
max_retries=3,
timeout=60,
)
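With the optional api_key, the OpenAI-compatible client can authenticate against secured Xinference deployments while still working when auth is off. A sketch of the resulting setup (placeholder server URL):

from openai import OpenAI

server_url = "http://localhost:9997"  # placeholder Xinference server
api_key = None  # set when Xinference auth is enabled

client = OpenAI(
    base_url=f"{server_url}/v1",
    api_key=api_key or "abc",  # SDK requires a non-empty key even if the server ignores it
    max_retries=3,
    timeout=60,
)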


@@ -44,15 +44,23 @@ class XinferenceRerankModel(RerankModel):
docs=[]
)
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
server_url = credentials['server_url']
model_uid = credentials['model_uid']
api_key = credentials.get('api_key')
if server_url.endswith('/'):
server_url = server_url[:-1]
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
try:
handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
response = handle.rerank(
documents=docs,
query=query,
top_n=top_n,
)
except RuntimeError as e:
raise InvokeServerUnavailableError(str(e))
handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
response = handle.rerank(
documents=docs,
query=query,
top_n=top_n,
)
rerank_documents = []
for idx, result in enumerate(response['results']):
@@ -102,7 +110,7 @@ class XinferenceRerankModel(RerankModel):
if not isinstance(xinference_client, RESTfulRerankModelHandle):
raise InvokeBadRequestError(
'please check model type, the model you want to invoke is not a rerank model')
self.invoke(
model=model,
credentials=credentials,


@@ -99,9 +99,9 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
}
def _speech2text_invoke(
self,
model: str,
credentials: dict,
self,
model: str,
credentials: dict,
file: IO[bytes],
language: Optional[str] = None,
prompt: Optional[str] = None,
@@ -121,17 +121,24 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
:param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
:return: text for given audio file
"""
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
server_url = credentials['server_url']
model_uid = credentials['model_uid']
api_key = credentials.get('api_key')
if server_url.endswith('/'):
server_url = server_url[:-1]
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
response = handle.transcriptions(
audio=file,
language = language,
prompt = prompt,
response_format = response_format,
temperature = temperature
)
try:
handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers)
response = handle.transcriptions(
audio=file,
language=language,
prompt=prompt,
response_format=response_format,
temperature=temperature
)
except RuntimeError as e:
raise InvokeServerUnavailableError(str(e))
return response["text"]

View File

@@ -43,16 +43,17 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
server_url = credentials['server_url']
model_uid = credentials['model_uid']
api_key = credentials.get('api_key')
if server_url.endswith('/'):
server_url = server_url[:-1]
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
try:
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers)
embeddings = handle.create_embedding(input=texts)
except RuntimeError as e:
raise InvokeServerUnavailableError(e)
raise InvokeServerUnavailableError(str(e))
"""
for convenience, the response json is like:
class Embedding(TypedDict):
@@ -106,7 +107,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
try:
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
@@ -117,7 +118,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
server_url = server_url[:-1]
client = Client(base_url=server_url)
try:
handle = client.get_model(model_uid=model_uid)
except RuntimeError as e:
@@ -151,7 +152,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
KeyError
]
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
@@ -186,7 +187,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(

View File

@@ -46,3 +46,12 @@ model_credential_schema:
placeholder:
zh_Hans: 在此输入您的Model UID
en_US: Enter the model uid
- variable: api_key
label:
zh_Hans: API密钥
en_US: API key
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的API密钥
en_US: Enter the api key

View File

@@ -7,8 +7,8 @@ _import_err_msg = (
"`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from flask import current_app
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -36,7 +36,7 @@ class AnalyticdbConfig(BaseModel):
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
class AnalyticdbVector(BaseVector):
_instance = None
_init = False
@@ -45,7 +45,7 @@ class AnalyticdbVector(BaseVector):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, collection_name: str, config: AnalyticdbConfig):
# collection_name must be updated every time
self._collection_name = collection_name.lower()
@@ -105,7 +105,7 @@ class AnalyticdbVector(BaseVector):
raise ValueError(
f"failed to create namespace {self.config.namespace}: {e}"
)
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
@@ -149,7 +149,7 @@ class AnalyticdbVector(BaseVector):
def get_type(self) -> str:
return VectorType.ANALYTICDB
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection_if_not_exists(dimension)
@@ -199,7 +199,7 @@ class AnalyticdbVector(BaseVector):
)
response = self._client.query_collection_data(request)
return len(response.body.matches.match) > 0
def delete_by_ids(self, ids: list[str]) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
ids_str = ",".join(f"'{id}'" for id in ids)
@@ -260,7 +260,7 @@ class AnalyticdbVector(BaseVector):
)
documents.append(doc)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = (
@@ -291,7 +291,7 @@ class AnalyticdbVector(BaseVector):
)
documents.append(doc)
return documents
def delete(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.DeleteCollectionRequest(
@@ -316,17 +316,18 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
)
config = current_app.config
# TODO handle optional params
return AnalyticdbVector(
collection_name,
AnalyticdbConfig(
access_key_id=config.get("ANALYTICDB_KEY_ID"),
access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
region_id=config.get("ANALYTICDB_REGION_ID"),
instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
account=config.get("ANALYTICDB_ACCOUNT"),
account_password=config.get("ANALYTICDB_PASSWORD"),
namespace=config.get("ANALYTICDB_NAMESPACE"),
namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
access_key_id=dify_config.ANALYTICDB_KEY_ID,
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
region_id=dify_config.ANALYTICDB_REGION_ID,
instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
namespace=dify_config.ANALYTICDB_NAMESPACE,
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
),
)
)
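
Note: this and the following factory hunks replace untyped current_app.config.get(...) lookups with attribute access on the typed dify_config settings object, so misspelled keys fail at definition time instead of silently returning None. A minimal sketch of the idea, assuming a pydantic-settings based config (VectorDBConfig is an illustrative subset, not the real DifyConfig):

from pydantic_settings import BaseSettings

class VectorDBConfig(BaseSettings):
    # Illustrative subset; the real settings class declares many more fields.
    ANALYTICDB_KEY_ID: str | None = None
    ANALYTICDB_REGION_ID: str | None = None

dify_config = VectorDBConfig()
dify_config.ANALYTICDB_KEY_ID  # typed access; a typo raises AttributeError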

View File

@@ -3,9 +3,9 @@ from typing import Any, Optional
import chromadb
from chromadb import QueryResult, Settings
from flask import current_app
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -111,7 +111,8 @@ class ChromaVector(BaseVector):
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -133,15 +134,14 @@ class ChromaVectorFactory(AbstractVectorFactory):
}
dataset.index_struct = json.dumps(index_struct_dict)
config = current_app.config
return ChromaVector(
collection_name=collection_name,
config=ChromaConfig(
host=config.get('CHROMA_HOST'),
port=int(config.get('CHROMA_PORT')),
tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT),
database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE),
auth_provider=config.get('CHROMA_AUTH_PROVIDER'),
auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'),
host=dify_config.CHROMA_HOST,
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
),
)
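
Note: besides the config migration, this hunk adds an explicit descending sort on the similarity score, so Chroma (and Qdrant and Weaviate below) return the best match first regardless of the store's native ordering. A runnable sketch of the ordering fix using stand-in documents:

docs = [
    {'page_content': 'a', 'metadata': {'score': 0.42}},
    {'page_content': 'b', 'metadata': {'score': 0.87}},
]
docs = sorted(docs, key=lambda d: d['metadata']['score'], reverse=True)
assert docs[0]['metadata']['score'] == 0.87  # best match first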

View File

@@ -3,10 +3,10 @@ import logging
from typing import Any, Optional
from uuid import uuid4
from flask import current_app
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException, connections
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -275,15 +275,14 @@ class MilvusVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MILVUS, collection_name))
config = current_app.config
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
database=config.get('MILVUS_DATABASE'),
host=dify_config.MILVUS_HOST,
port=dify_config.MILVUS_PORT,
user=dify_config.MILVUS_USER,
password=dify_config.MILVUS_PASSWORD,
secure=dify_config.MILVUS_SECURE,
database=dify_config.MILVUS_DATABASE,
)
)

View File

@@ -5,9 +5,9 @@ from enum import Enum
from typing import Any
from clickhouse_connect import get_client
from flask import current_app
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -156,15 +156,14 @@ class MyScaleVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
config = current_app.config
return MyScaleVector(
collection_name=collection_name,
config=MyScaleConfig(
host=config.get("MYSCALE_HOST", "localhost"),
port=int(config.get("MYSCALE_PORT", 8123)),
user=config.get("MYSCALE_USER", "default"),
password=config.get("MYSCALE_PASSWORD", ""),
database=config.get("MYSCALE_DATABASE", "default"),
fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
host=dify_config.MYSCALE_HOST,
port=dify_config.MYSCALE_PORT,
user=dify_config.MYSCALE_USER,
password=dify_config.MYSCALE_PASSWORD,
database=dify_config.MYSCALE_DATABASE,
fts_params=dify_config.MYSCALE_FTS_PARAMS,
),
)

View File

@@ -4,11 +4,11 @@ import ssl
from typing import Any, Optional
from uuid import uuid4
from flask import current_app
from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -257,14 +257,13 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
config = current_app.config
open_search_config = OpenSearchConfig(
host=config.get('OPENSEARCH_HOST'),
port=config.get('OPENSEARCH_PORT'),
user=config.get('OPENSEARCH_USER'),
password=config.get('OPENSEARCH_PASSWORD'),
secure=config.get('OPENSEARCH_SECURE'),
host=dify_config.OPENSEARCH_HOST,
port=dify_config.OPENSEARCH_PORT,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
secure=dify_config.OPENSEARCH_SECURE,
)
return OpenSearchVector(

View File

@@ -6,9 +6,9 @@ from typing import Any
import numpy
import oracledb
from flask import current_app
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -44,11 +44,11 @@ class OracleVectorConfig(BaseModel):
SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id varchar2(100)
id varchar2(100)
,text CLOB NOT NULL
,meta JSON
,embedding vector NOT NULL
)
)
"""
@@ -219,14 +219,13 @@ class OracleVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ORACLE, collection_name))
config = current_app.config
return OracleVector(
collection_name=collection_name,
config=OracleVectorConfig(
host=config.get("ORACLE_HOST"),
port=config.get("ORACLE_PORT"),
user=config.get("ORACLE_USER"),
password=config.get("ORACLE_PASSWORD"),
database=config.get("ORACLE_DATABASE"),
host=dify_config.ORACLE_HOST,
port=dify_config.ORACLE_PORT,
user=dify_config.ORACLE_USER,
password=dify_config.ORACLE_PASSWORD,
database=dify_config.ORACLE_DATABASE,
),
)

View File

@@ -3,7 +3,6 @@ import logging
from typing import Any
from uuid import UUID, uuid4
from flask import current_app
from numpy import ndarray
from pgvecto_rs.sqlalchemy import Vector
from pydantic import BaseModel, model_validator
@@ -12,6 +11,7 @@ from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -93,7 +93,7 @@ class PGVectoRS(BaseVector):
text TEXT NOT NULL,
meta JSONB NOT NULL,
vector vector({dimension}) NOT NULL
) using heap;
) using heap;
""")
session.execute(create_statement)
index_statement = sql_text(f"""
@@ -233,15 +233,15 @@ class PGVectoRSFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
dim = len(embeddings.embed_query("pgvecto_rs"))
config = current_app.config
return PGVectoRS(
collection_name=collection_name,
config=PgvectoRSConfig(
host=config.get('PGVECTO_RS_HOST'),
port=config.get('PGVECTO_RS_PORT'),
user=config.get('PGVECTO_RS_USER'),
password=config.get('PGVECTO_RS_PASSWORD'),
database=config.get('PGVECTO_RS_DATABASE'),
host=dify_config.PGVECTO_RS_HOST,
port=dify_config.PGVECTO_RS_PORT,
user=dify_config.PGVECTO_RS_USER,
password=dify_config.PGVECTO_RS_PASSWORD,
database=dify_config.PGVECTO_RS_DATABASE,
),
dim=dim
)
)

View File

@@ -5,9 +5,9 @@ from typing import Any
import psycopg2.extras
import psycopg2.pool
from flask import current_app
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding vector({dimension}) NOT NULL
) using heap;
) using heap;
"""
@@ -185,14 +185,13 @@ class PGVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))
config = current_app.config
return PGVector(
collection_name=collection_name,
config=PGVectorConfig(
host=config.get("PGVECTOR_HOST"),
port=config.get("PGVECTOR_PORT"),
user=config.get("PGVECTOR_USER"),
password=config.get("PGVECTOR_PASSWORD"),
database=config.get("PGVECTOR_DATABASE"),
host=dify_config.PGVECTOR_HOST,
port=dify_config.PGVECTOR_PORT,
user=dify_config.PGVECTOR_USER,
password=dify_config.PGVECTOR_PASSWORD,
database=dify_config.PGVECTOR_DATABASE,
),
)
)

View File

@@ -19,6 +19,7 @@ from qdrant_client.http.models import (
)
from qdrant_client.local.qdrant_local import QdrantLocal
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -361,6 +362,8 @@ class QdrantVector(BaseVector):
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -444,11 +447,11 @@ class QdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
endpoint=dify_config.QDRANT_URL,
api_key=dify_config.QDRANT_API_KEY,
root_path=config.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
grpc_port=config.get('QDRANT_GRPC_PORT'),
prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
)
)

View File

@@ -2,7 +2,6 @@ import json
import uuid
from typing import Any, Optional
from flask import current_app
from pydantic import BaseModel, model_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
@@ -19,6 +18,7 @@ try:
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
@@ -85,7 +85,7 @@ class RelytVector(BaseVector):
document TEXT NOT NULL,
metadata JSON NOT NULL,
embedding vector({dimension}) NOT NULL
) using heap;
) using heap;
""")
session.execute(create_statement)
index_statement = sql_text(f"""
@@ -313,15 +313,14 @@ class RelytVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.RELYT, collection_name))
config = current_app.config
return RelytVector(
collection_name=collection_name,
config=RelytConfig(
host=config.get('RELYT_HOST'),
port=config.get('RELYT_PORT'),
user=config.get('RELYT_USER'),
password=config.get('RELYT_PASSWORD'),
database=config.get('RELYT_DATABASE'),
host=dify_config.RELYT_HOST,
port=dify_config.RELYT_PORT,
user=dify_config.RELYT_USER,
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
),
group_id=dataset.id
)

View File

@@ -1,13 +1,13 @@
import json
from typing import Any, Optional
from flask import current_app
from pydantic import BaseModel
from tcvectordb import VectorDBClient
from tcvectordb.model import document, enum
from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -212,16 +212,15 @@ class TencentVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TENCENT, collection_name))
config = current_app.config
return TencentVector(
collection_name=collection_name,
config=TencentConfig(
url=config.get('TENCENT_VECTOR_DB_URL'),
api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
username=config.get('TENCENT_VECTOR_DB_USERNAME'),
database=config.get('TENCENT_VECTOR_DB_DATABASE'),
shard=config.get('TENCENT_VECTOR_DB_SHARD'),
replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
url=dify_config.TENCENT_VECTOR_DB_URL,
api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
username=dify_config.TENCENT_VECTOR_DB_USERNAME,
database=dify_config.TENCENT_VECTOR_DB_DATABASE,
shard=dify_config.TENCENT_VECTOR_DB_SHARD,
replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
)
)
)

View File

@@ -3,12 +3,12 @@ import logging
from typing import Any
import sqlalchemy
from flask import current_app
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -198,8 +198,8 @@ class TiDBVector(BaseVector):
with Session(self._engine) as session:
select_statement = sql_text(
f"""SELECT meta, text, distance FROM (
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
FROM {self._collection_name}
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
FROM {self._collection_name}
ORDER BY distance
LIMIT {top_k}
) t WHERE distance < {distance};"""
@@ -234,15 +234,14 @@ class TiDBVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))
config = current_app.config
return TiDBVector(
collection_name=collection_name,
config=TiDBVectorConfig(
host=config.get('TIDB_VECTOR_HOST'),
port=config.get('TIDB_VECTOR_PORT'),
user=config.get('TIDB_VECTOR_USER'),
password=config.get('TIDB_VECTOR_PASSWORD'),
database=config.get('TIDB_VECTOR_DATABASE'),
program_name=config.get('APPLICATION_NAME'),
host=dify_config.TIDB_VECTOR_HOST,
port=dify_config.TIDB_VECTOR_PORT,
user=dify_config.TIDB_VECTOR_USER,
password=dify_config.TIDB_VECTOR_PASSWORD,
database=dify_config.TIDB_VECTOR_DATABASE,
program_name=dify_config.APPLICATION_NAME,
),
)
)

View File

@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any
from flask import current_app
from configs import dify_config
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -37,8 +36,7 @@ class Vector:
self._vector_processor = self._init_vector()
def _init_vector(self) -> BaseVector:
config = current_app.config
vector_type = config.get('VECTOR_STORE')
vector_type = dify_config.VECTOR_STORE
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
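
Note: the resolution order matters here: an existing dataset's index_struct pins the vector store type, and the configured VECTOR_STORE only applies to new datasets. A small sketch of that precedence (resolve_vector_type is an illustrative name):

def resolve_vector_type(index_struct_dict: dict | None, configured: str) -> str:
    # An existing dataset keeps the store it was indexed with.
    return index_struct_dict['type'] if index_struct_dict else configured

assert resolve_vector_type({'type': 'qdrant'}, 'weaviate') == 'qdrant'
assert resolve_vector_type(None, 'weaviate') == 'weaviate'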

View File

@@ -4,9 +4,9 @@ from typing import Any, Optional
import requests
import weaviate
from flask import current_app
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -216,7 +216,8 @@ class WeaviateVector(BaseVector):
if score > score_threshold:
doc.metadata['score'] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -281,9 +282,9 @@ class WeaviateVectorFactory(AbstractVectorFactory):
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=current_app.config.get('WEAVIATE_ENDPOINT'),
api_key=current_app.config.get('WEAVIATE_API_KEY'),
batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE'))
endpoint=dify_config.WEAVIATE_ENDPOINT,
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE
),
attributes=attributes
)

View File

@@ -5,8 +5,8 @@ from typing import Union
from urllib.parse import unquote
import requests
from flask import current_app
from configs import dify_config
from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -94,9 +94,9 @@ class ExtractProcessor:
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
etl_type = current_app.config['ETL_TYPE']
unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
unstructured_api_key = current_app.config['UNSTRUCTURED_API_KEY']
etl_type = dify_config.ETL_TYPE
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
if etl_type == 'Unstructured':
if file_extension == '.xlsx' or file_extension == '.xls':
extractor = ExcelExtractor(file_path)

View File

@@ -3,8 +3,8 @@ import logging
from typing import Any, Optional
import requests
from flask import current_app
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
@@ -49,7 +49,7 @@ class NotionExtractor(BaseExtractor):
self._notion_access_token = self._get_access_token(tenant_id,
self._notion_workspace_id)
if not self._notion_access_token:
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
integration_token = dify_config.NOTION_INTEGRATION_TOKEN
if integration_token is None:
raise ValueError(
"Must specify `integration_token` or set environment "

View File

@@ -8,8 +8,8 @@ from urllib.parse import urlparse
import requests
from docx import Document as DocxDocument
from flask import current_app
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
@@ -96,10 +96,9 @@ class WordExtractor(BaseExtractor):
storage.save(file_key, rel.target_part.blob)
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=self.tenant_id,
storage_type=config['STORAGE_TYPE'],
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=0,
@@ -114,7 +113,7 @@ class WordExtractor(BaseExtractor):
db.session.add(upload_file)
db.session.commit()
image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"
image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"
return image_map

View File

@@ -2,8 +2,7 @@
from abc import ABC, abstractmethod
from typing import Optional
from flask import current_app
from configs import dify_config
from core.model_manager import ModelInstance
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
@@ -48,7 +47,7 @@ class BaseIndexProcessor(ABC):
# The user-defined segmentation rule
rules = processing_rule['rules']
segmentation = rules["segmentation"]
max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")

View File

@@ -30,3 +30,4 @@
- feishu
- feishu_base
- slack
- tianditu

View File

@@ -1,5 +1,5 @@
import base64
import random
from base64 import b64decode
from typing import Any, Union
from openai import OpenAI
@@ -69,11 +69,50 @@ class DallE3Tool(BuiltinTool):
result = []
for image in response.data:
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={'mime_type': 'image/png'},
save_as=self.VARIABLE_KEY.IMAGE.value))
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message(blob=blob_image,
meta={'mime_type': mime_type},
save_as=self.VARIABLE_KEY.IMAGE.value)
result.append(blob_message)
return result
@staticmethod
def _decode_image(base64_image: str) -> tuple[str, bytes]:
"""
Decode a base64 encoded image. If the image is not prefixed with a MIME type,
it assumes 'image/png' as the default.
:param base64_image: Base64 encoded image string
:return: A tuple containing the MIME type and the decoded image bytes
"""
if DallE3Tool._is_plain_base64(base64_image):
return 'image/png', base64.b64decode(base64_image)
else:
return DallE3Tool._extract_mime_and_data(base64_image)
@staticmethod
def _is_plain_base64(encoded_str: str) -> bool:
"""
Check if the given encoded string is plain base64 without a MIME type prefix.
:param encoded_str: Base64 encoded image string
:return: True if the string is plain base64, False otherwise
"""
return not encoded_str.startswith('data:image')
@staticmethod
def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
"""
Extract MIME type and image data from a base64 encoded string with a MIME type prefix.
:param encoded_str: Base64 encoded image string with MIME type prefix
:return: A tuple containing the MIME type and the decoded image bytes
"""
mime_type = encoded_str.split(';')[0].split(':')[1]
image_data_base64 = encoded_str.split(',')[1]
decoded_data = base64.b64decode(image_data_base64)
return mime_type, decoded_data
@staticmethod
def _generate_random_id(length=8):
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
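
Note: _decode_image handles both payload shapes a DALL-E backend may return — a bare base64 string and a data-URI with a MIME prefix. A usage sketch, assuming DallE3Tool is importable from this tool module; the byte strings are fabricated:

import base64

plain = base64.b64encode(b'fake-png-bytes').decode()
prefixed = 'data:image/jpeg;base64,' + base64.b64encode(b'fake-jpeg-bytes').decode()

mime, _ = DallE3Tool._decode_image(plain)      # -> 'image/png' (the default)
mime, _ = DallE3Tool._decode_image(prefixed)   # -> 'image/jpeg' (parsed from the prefix)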

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 256 256"><rect width="256" height="256" fill="none"/><rect x="32" y="48" width="192" height="160" rx="8" fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/><circle cx="156" cy="100" r="12"/><path d="M147.31,164,173,138.34a8,8,0,0,1,11.31,0L224,178.06" fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/><path d="M32,168.69l54.34-54.35a8,8,0,0,1,11.32,0L191.31,208" fill="none" stroke="#1553ed" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/></svg>


View File

@@ -0,0 +1,22 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.getimgai.tools.text2image import Text2ImageTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class GetImgAIProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
try:
# Example validation using the text2image tool
Text2ImageTool().fork_tool_runtime(
runtime={"credentials": credentials}
).invoke(
user_id='',
tool_parameters={
"prompt": "A fire egg",
"response_format": "url",
"style": "photorealism",
}
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@@ -0,0 +1,29 @@
identity:
author: Matri Qi
name: getimgai
label:
en_US: getimg.ai
zh_CN: getimg.ai
description:
en_US: GetImg API integration for image generation.
icon: icon.svg
tags:
- image
credentials_for_provider:
getimg_api_key:
type: secret-input
required: true
label:
en_US: getimg.ai API Key
placeholder:
en_US: Please input your getimg.ai API key
help:
en_US: Get your getimg.ai API key from your getimg.ai account settings. If you are using a self-hosted version, you may enter any key at your convenience.
url: https://dashboard.getimg.ai/api-keys
base_url:
type: text-input
required: false
label:
en_US: getimg.ai server's Base URL
placeholder:
en_US: https://api.getimg.ai/v1

View File

@@ -0,0 +1,59 @@
import logging
import time
from collections.abc import Mapping
from typing import Any
import requests
from requests.exceptions import HTTPError
logger = logging.getLogger(__name__)
class GetImgAIApp:
def __init__(self, api_key: str | None = None, base_url: str | None = None):
self.api_key = api_key
self.base_url = base_url or 'https://api.getimg.ai/v1'
if not self.api_key:
raise ValueError("API key is required")
def _prepare_headers(self):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
return headers
def _request(
self,
method: str,
url: str,
data: Mapping[str, Any] | None = None,
headers: Mapping[str, str] | None = None,
retries: int = 3,
backoff_factor: float = 0.3,
) -> Mapping[str, Any] | None:
for i in range(retries):
try:
response = requests.request(method, url, json=data, headers=headers)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500:
time.sleep(backoff_factor * (2 ** i))
else:
raise
return None
def text2image(
self, mode: str, **kwargs
):
data = kwargs['params']
if not data.get('prompt'):
raise ValueError("Prompt is required")
endpoint = f'{self.base_url}/{mode}/text-to-image'
headers = self._prepare_headers()
logger.debug(f"Send request to {endpoint=} body={data}")
response = self._request('POST', endpoint, data, headers)
if response is None:
raise HTTPError("Failed to initiate getimg.ai after multiple retries")
return response
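
Note: _request retries only 5xx HTTP errors with exponential backoff and re-raises everything else. A usage sketch for the client defined above; the API key is a placeholder and the call performs a real HTTP request:

app = GetImgAIApp(api_key='YOUR_GETIMG_API_KEY')
result = app.text2image(
    mode='essential-v2',
    params={'prompt': 'a lighthouse at dusk', 'response_format': 'url'},
)
print(result)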

View File

@@ -0,0 +1,39 @@
import json
from typing import Any, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.getimgai.getimgai_appx import GetImgAIApp
from core.tools.tool.builtin_tool import BuiltinTool
class Text2ImageTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials.get('base_url'))
options = {
'style': tool_parameters.get('style'),
'prompt': tool_parameters.get('prompt'),
'aspect_ratio': tool_parameters.get('aspect_ratio'),
'output_format': tool_parameters.get('output_format', 'jpeg'),
'response_format': tool_parameters.get('response_format', 'url'),
'width': tool_parameters.get('width'),
'height': tool_parameters.get('height'),
'steps': tool_parameters.get('steps'),
'negative_prompt': tool_parameters.get('negative_prompt'),
'prompt_2': tool_parameters.get('prompt_2'),
}
options = {k: v for k, v in options.items() if v}
text2image_result = app.text2image(
mode=tool_parameters.get('mode', 'essential-v2'),
params=options,
wait=True
)
if not isinstance(text2image_result, str):
text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4)
if not text2image_result:
return self.create_text_message("getimg.ai request failed.")
return self.create_text_message(text2image_result)

View File

@@ -0,0 +1,167 @@
identity:
name: text2image
author: Matri Qi
label:
en_US: text2image
icon: icon.svg
description:
human:
en_US: Generate image via getimg.ai.
llm: This tool is used to generate image from prompt or image via https://getimg.ai.
parameters:
- name: prompt
type: string
required: true
label:
en_US: prompt
human_description:
en_US: The text prompt used to generate the image. getimg.ai will generate an image based on this prompt.
llm_description: this prompt text will be used to generate image.
form: llm
- name: mode
type: select
required: false
label:
en_US: mode
human_description:
en_US: The getimg.ai mode to use. The mode determines the endpoint used to generate the image.
form: form
options:
- value: "essential-v2"
label:
en_US: essential-v2
- value: stable-diffusion-xl
label:
en_US: stable-diffusion-xl
- value: stable-diffusion
label:
en_US: stable-diffusion
- value: latent-consistency
label:
en_US: latent-consistency
- name: style
type: select
required: false
label:
en_US: style
human_description:
en_US: The style preset to use. The style preset guides the generation towards a particular style. It only takes effect in `Essential V2` mode.
form: form
options:
- value: photorealism
label:
en_US: photorealism
- value: anime
label:
en_US: anime
- value: art
label:
en_US: art
- name: aspect_ratio
type: select
required: false
label:
en_US: "aspect ratio"
human_description:
en_US: The aspect ratio of the generated image. It only takes effect in `Essential V2` mode.
form: form
options:
- value: "1:1"
label:
en_US: "1:1"
- value: "4:5"
label:
en_US: "4:5"
- value: "5:4"
label:
en_US: "5:4"
- value: "2:3"
label:
en_US: "2:3"
- value: "3:2"
label:
en_US: "3:2"
- value: "4:7"
label:
en_US: "4:7"
- value: "7:4"
label:
en_US: "7:4"
- name: output_format
type: select
required: false
label:
en_US: "output format"
human_description:
en_US: The file format of the generated image.
form: form
options:
- value: jpeg
label:
en_US: jpeg
- value: png
label:
en_US: png
- name: response_format
type: select
required: false
label:
en_US: "response format"
human_description:
en_US: The format in which the generated images are returned. Must be one of url or b64. URLs are only valid for 1 hour after the image has been generated.
form: form
options:
- value: url
label:
en_US: url
- value: b64
label:
en_US: b64
- name: model
type: string
required: false
label:
en_US: model
human_description:
en_US: Model ID supported by this pipeline and family. It only takes effect in `Stable Diffusion XL`, `Stable Diffusion`, and `Latent Consistency` modes.
form: form
- name: negative_prompt
type: string
required: false
label:
en_US: negative prompt
human_description:
en_US: Text input that will not guide the image generation. It only takes effect in `Stable Diffusion XL`, `Stable Diffusion`, and `Latent Consistency` modes.
form: form
- name: prompt_2
type: string
required: false
label:
en_US: prompt2
human_description:
en_US: Prompt sent to the second tokenizer and text encoder. If not defined, prompt is used in both text encoders. It only takes effect in `Stable Diffusion XL` mode.
form: form
- name: width
type: number
required: false
label:
en_US: width
human_description:
en_US: The width of the generated image in pixels. Width needs to be a multiple of 64.
form: form
- name: height
type: number
required: false
label:
en_US: height
human_description:
en_US: The height of the generated image in pixels. Height needs to be a multiple of 64.
form: form
- name: steps
type: number
required: false
label:
en_US: steps
human_description:
en_US: The number of denoising steps. More steps usually produce higher quality images but take more time to generate. It only takes effect in `Stable Diffusion XL`, `Stable Diffusion`, and `Latent Consistency` modes.
form: form

View File

@@ -19,28 +19,29 @@ class JSONDeleteTool(BuiltinTool):
content = tool_parameters.get('content', '')
if not content:
return self.create_text_message('Invalid parameter content')
# Get query
query = tool_parameters.get('query', '')
if not query:
return self.create_text_message('Invalid parameter query')
ensure_ascii = tool_parameters.get('ensure_ascii', True)
try:
result = self._delete(content, query)
result = self._delete(content, query, ensure_ascii)
return self.create_text_message(str(result))
except Exception as e:
return self.create_text_message(f'Failed to delete JSON content: {str(e)}')
def _delete(self, origin_json: str, query: str) -> str:
def _delete(self, origin_json: str, query: str, ensure_ascii: bool) -> str:
try:
input_data = json.loads(origin_json)
expr = parse('$.' + query.lstrip('$.')) # Ensure query path starts with $
matches = expr.find(input_data)
if not matches:
return json.dumps(input_data, ensure_ascii=True) # No changes if no matches found
return json.dumps(input_data, ensure_ascii=ensure_ascii) # No changes if no matches found
for match in matches:
if isinstance(match.context.value, dict):
# Delete key from dictionary
@@ -53,7 +54,7 @@ class JSONDeleteTool(BuiltinTool):
parent = match.context.parent
if parent:
del parent.value[match.path.fields[-1]]
return json.dumps(input_data, ensure_ascii=True)
return json.dumps(input_data, ensure_ascii=ensure_ascii)
except Exception as e:
raise Exception(f"Delete operation failed: {str(e)}")
raise Exception(f"Delete operation failed: {str(e)}")

View File

@@ -38,3 +38,15 @@ parameters:
pt_BR: JSONPath query to locate the element to delete
llm_description: JSONPath query to locate the element to delete
form: llm
- name: ensure_ascii
type: boolean
default: true
label:
en_US: Ensure ASCII
zh_Hans: 确保 ASCII
pt_BR: Ensure ASCII
human_description:
en_US: Ensure the JSON output is ASCII encoded
zh_Hans: 确保输出的 JSON 是 ASCII 编码
pt_BR: Ensure the JSON output is ASCII encoded
form: form

View File

@@ -19,31 +19,31 @@ class JSONParseTool(BuiltinTool):
content = tool_parameters.get('content', '')
if not content:
return self.create_text_message('Invalid parameter content')
# get query
query = tool_parameters.get('query', '')
if not query:
return self.create_text_message('Invalid parameter query')
# get new value
new_value = tool_parameters.get('new_value', '')
if not new_value:
return self.create_text_message('Invalid parameter new_value')
# get insert position
index = tool_parameters.get('index')
# get create path
create_path = tool_parameters.get('create_path', False)
ensure_ascii = tool_parameters.get('ensure_ascii', True)
try:
result = self._insert(content, query, new_value, index, create_path)
result = self._insert(content, query, new_value, ensure_ascii, index, create_path)
return self.create_text_message(str(result))
except Exception:
return self.create_text_message('Failed to insert JSON content')
def _insert(self, origin_json, query, new_value, index=None, create_path=False):
def _insert(self, origin_json, query, new_value, ensure_ascii: bool, index=None, create_path=False):
try:
input_data = json.loads(origin_json)
expr = parse(query)
@@ -51,9 +51,9 @@ class JSONParseTool(BuiltinTool):
new_value = json.loads(new_value)
except json.JSONDecodeError:
new_value = new_value
matches = expr.find(input_data)
if not matches and create_path:
# create new path
path_parts = query.strip('$').strip('.').split('.')
@@ -91,7 +91,7 @@ class JSONParseTool(BuiltinTool):
else:
# replace old value with new value
match.full_path.update(input_data, new_value)
return json.dumps(input_data, ensure_ascii=True)
return json.dumps(input_data, ensure_ascii=ensure_ascii)
except Exception as e:
return str(e)
return str(e)

View File

@@ -75,3 +75,15 @@ parameters:
zh_Hans:
pt_BR: "No"
form: form
- name: ensure_ascii
type: boolean
default: true
label:
en_US: Ensure ASCII
zh_Hans: 确保 ASCII
pt_BR: Ensure ASCII
human_description:
en_US: Ensure the JSON output is ASCII encoded
zh_Hans: 确保输出的 JSON 是 ASCII 编码
pt_BR: Ensure the JSON output is ASCII encoded
form: form

View File

@@ -19,33 +19,34 @@ class JSONParseTool(BuiltinTool):
content = tool_parameters.get('content', '')
if not content:
return self.create_text_message('Invalid parameter content')
# get json filter
json_filter = tool_parameters.get('json_filter', '')
if not json_filter:
return self.create_text_message('Invalid parameter json_filter')
ensure_ascii = tool_parameters.get('ensure_ascii', True)
try:
result = self._extract(content, json_filter)
result = self._extract(content, json_filter, ensure_ascii)
return self.create_text_message(str(result))
except Exception:
return self.create_text_message('Failed to extract JSON content')
# Extract data from JSON content
def _extract(self, content: str, json_filter: str) -> str:
def _extract(self, content: str, json_filter: str, ensure_ascii: bool) -> str:
try:
input_data = json.loads(content)
expr = parse(json_filter)
result = [match.value for match in expr.find(input_data)]
if len(result) == 1:
result = result[0]
if isinstance(result, dict | list):
return json.dumps(result, ensure_ascii=True)
return json.dumps(result, ensure_ascii=ensure_ascii)
elif isinstance(result, str | int | float | bool) or result is None:
return str(result)
else:
return repr(result)
except Exception as e:
return str(e)
return str(e)

View File

@@ -38,3 +38,15 @@ parameters:
pt_BR: JSON fields to be parsed
llm_description: JSON fields to be parsed
form: llm
- name: ensure_ascii
type: boolean
default: true
label:
en_US: Ensure ASCII
zh_Hans: 确保 ASCII
pt_BR: Ensure ASCII
human_description:
en_US: Ensure the JSON output is ASCII encoded
zh_Hans: 确保输出的 JSON 是 ASCII 编码
pt_BR: Ensure the JSON output is ASCII encoded
form: form

View File

@@ -19,61 +19,62 @@ class JSONReplaceTool(BuiltinTool):
content = tool_parameters.get('content', '')
if not content:
return self.create_text_message('Invalid parameter content')
# get query
query = tool_parameters.get('query', '')
if not query:
return self.create_text_message('Invalid parameter query')
# get replace value
replace_value = tool_parameters.get('replace_value', '')
if not replace_value:
return self.create_text_message('Invalid parameter replace_value')
# get replace model
replace_model = tool_parameters.get('replace_model', '')
if not replace_model:
return self.create_text_message('Invalid parameter replace_model')
ensure_ascii = tool_parameters.get('ensure_ascii', True)
try:
if replace_model == 'pattern':
# get replace pattern
replace_pattern = tool_parameters.get('replace_pattern', '')
if not replace_pattern:
return self.create_text_message('Invalid parameter replace_pattern')
result = self._replace_pattern(content, query, replace_pattern, replace_value)
result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii)
elif replace_model == 'key':
result = self._replace_key(content, query, replace_value)
result = self._replace_key(content, query, replace_value, ensure_ascii)
elif replace_model == 'value':
result = self._replace_value(content, query, replace_value)
result = self._replace_value(content, query, replace_value, ensure_ascii)
return self.create_text_message(str(result))
except Exception:
return self.create_text_message('Failed to replace JSON content')
# Replace pattern
def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str) -> str:
def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool) -> str:
try:
input_data = json.loads(content)
expr = parse(query)
matches = expr.find(input_data)
for match in matches:
new_value = match.value.replace(replace_pattern, replace_value)
match.full_path.update(input_data, new_value)
return json.dumps(input_data, ensure_ascii=True)
return json.dumps(input_data, ensure_ascii=ensure_ascii)
except Exception as e:
return str(e)
# Replace key
def _replace_key(self, content: str, query: str, replace_value: str) -> str:
def _replace_key(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str:
try:
input_data = json.loads(content)
expr = parse(query)
matches = expr.find(input_data)
for match in matches:
parent = match.context.value
if isinstance(parent, dict):
@@ -86,21 +87,21 @@ class JSONReplaceTool(BuiltinTool):
if isinstance(item, dict) and old_key in item:
value = item.pop(old_key)
item[replace_value] = value
return json.dumps(input_data, ensure_ascii=True)
return json.dumps(input_data, ensure_ascii=ensure_ascii)
except Exception as e:
return str(e)
# Replace value
def _replace_value(self, content: str, query: str, replace_value: str) -> str:
def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str:
try:
input_data = json.loads(content)
expr = parse(query)
matches = expr.find(input_data)
for match in matches:
match.full_path.update(input_data, replace_value)
return json.dumps(input_data, ensure_ascii=True)
return json.dumps(input_data, ensure_ascii=ensure_ascii)
except Exception as e:
return str(e)
return str(e)

View File

@@ -93,3 +93,15 @@ parameters:
zh_Hans: 字符串替换
pt_BR: replace string
form: form
- name: ensure_ascii
type: boolean
default: true
label:
en_US: Ensure ASCII
zh_Hans: 确保 ASCII
pt_BR: Ensure ASCII
human_description:
en_US: Ensure the JSON output is ASCII encoded
zh_Hans: 确保输出的 JSON 是 ASCII 编码
pt_BR: Ensure the JSON output is ASCII encoded
form: form

View File

@@ -0,0 +1 @@
<svg height="30" width="30" viewBox="0 0 36 34" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" class="fill-accent-foreground transition-all group-hover:scale-110"><title>Spider v1 Logo</title><path fill-rule="evenodd" clip-rule="evenodd" d="M9.13883 7.06589V0.164429L13.0938 0.164429V6.175L14.5178 7.4346C15.577 6.68656 16.7337 6.27495 17.945 6.27495C19.1731 6.27495 20.3451 6.69807 21.4163 7.46593L22.8757 6.175V0.164429L26.8307 0.164429V7.06589V7.95679L26.1634 8.54706L24.0775 10.3922C24.3436 10.8108 24.5958 11.2563 24.8327 11.7262L26.0467 11.4215L28.6971 8.08749L31.793 10.5487L28.7257 14.407L28.3089 14.9313L27.6592 15.0944L26.2418 15.4502C26.3124 15.7082 26.3793 15.9701 26.4422 16.2355L28.653 16.6566L29.092 16.7402L29.4524 17.0045L35.3849 21.355L33.0461 24.5444L27.474 20.4581L27.0719 20.3816C27.1214 21.0613 27.147 21.7543 27.147 22.4577C27.147 22.5398 27.1466 22.6214 27.1459 22.7024L29.5889 23.7911L30.3219 24.1177L30.62 24.8629L33.6873 32.5312L30.0152 34L27.246 27.0769L26.7298 26.8469C25.5612 32.2432 22.0701 33.8808 17.945 33.8808C13.8382 33.8808 10.3598 32.2577 9.17593 26.9185L8.82034 27.0769L6.05109 34L2.37897 32.5312L5.44629 24.8629L5.74435 24.1177L6.47743 23.7911L8.74487 22.7806C8.74366 22.6739 8.74305 22.5663 8.74305 22.4577C8.74305 21.7616 8.76804 21.0758 8.81654 20.4028L8.52606 20.4581L2.95395 24.5444L0.615112 21.355L6.54761 17.0045L6.908 16.7402L7.34701 16.6566L9.44264 16.2575C9.50917 15.9756 9.5801 15.6978 9.65528 15.4242L8.34123 15.0944L7.69155 14.9313L7.27471 14.407L4.20739 10.5487L7.30328 8.08749L9.95376 11.4215L11.0697 11.7016C11.3115 11.2239 11.5692 10.7716 11.8412 10.3473L9.80612 8.54706L9.13883 7.95679V7.06589Z"></path></svg>


View File

@@ -0,0 +1,14 @@
from typing import Any
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.spider.spiderApp import Spider
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
class SpiderProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
app = Spider(api_key=credentials["spider_api_key"])
app.scrape_url(url="https://spider.cloud")
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@@ -0,0 +1,27 @@
identity:
author: William Espegren
name: spider
label:
en_US: Spider
zh_CN: Spider
description:
en_US: Spider API integration, returning LLM-ready data by scraping & crawling websites.
zh_CN: Spider API 集成,通过爬取和抓取网站返回 LLM-ready 数据。
icon: icon.svg
tags:
- search
- utilities
credentials_for_provider:
spider_api_key:
type: secret-input
required: true
label:
en_US: Spider API Key
zh_CN: Spider API 密钥
placeholder:
en_US: Please input your Spider API key
zh_CN: 请输入您的 Spider API 密钥
help:
en_US: Get your Spider API key from your Spider dashboard
zh_CN: 从您的 Spider 仪表板中获取 Spider API 密钥。
url: https://spider.cloud/

View File

@@ -0,0 +1,237 @@
import os
from typing import Literal, Optional, TypedDict
import requests
class RequestParamsDict(TypedDict, total=False):
url: Optional[str]
request: Optional[Literal["http", "chrome", "smart"]]
limit: Optional[int]
return_format: Optional[Literal["raw", "markdown", "html2text", "text", "bytes"]]
tld: Optional[bool]
depth: Optional[int]
cache: Optional[bool]
budget: Optional[dict[str, int]]
locale: Optional[str]
cookies: Optional[str]
stealth: Optional[bool]
headers: Optional[dict[str, str]]
anti_bot: Optional[bool]
metadata: Optional[bool]
viewport: Optional[dict[str, int]]
encoding: Optional[str]
subdomains: Optional[bool]
user_agent: Optional[str]
store_data: Optional[bool]
gpt_config: Optional[list[str]]
fingerprint: Optional[bool]
storageless: Optional[bool]
readability: Optional[bool]
proxy_enabled: Optional[bool]
respect_robots: Optional[bool]
query_selector: Optional[str]
full_resources: Optional[bool]
request_timeout: Optional[int]
run_in_background: Optional[bool]
skip_config_checks: Optional[bool]
class Spider:
def __init__(self, api_key: Optional[str] = None):
"""
Initialize the Spider with an API key.
:param api_key: A string of the API key for Spider. Defaults to the SPIDER_API_KEY environment variable.
:raises ValueError: If no API key is provided.
"""
self.api_key = api_key or os.getenv("SPIDER_API_KEY")
if self.api_key is None:
raise ValueError("No API key provided")
def api_post(
self,
endpoint: str,
data: dict,
stream: bool,
content_type: str = "application/json",
):
"""
Send a POST request to the specified API endpoint.
:param endpoint: The API endpoint to which the POST request is sent.
:param data: The data (dictionary) to be sent in the POST request.
:param stream: Boolean indicating if the response should be streamed.
:return: The JSON response or the raw response stream if stream is True.
"""
headers = self._prepare_headers(content_type)
response = self._post_request(
f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream
)
if stream:
return response
elif response.status_code == 200:
return response.json()
else:
self._handle_error(response, f"post to {endpoint}")
def api_get(
self, endpoint: str, stream: bool, content_type: str = "application/json"
):
"""
Send a GET request to the specified endpoint.
:param endpoint: The API endpoint from which to retrieve data.
:return: The JSON decoded response.
"""
headers = self._prepare_headers(content_type)
response = self._get_request(
f"https://api.spider.cloud/v1/{endpoint}", headers, stream
)
if response.status_code == 200:
return response.json()
else:
self._handle_error(response, f"get from {endpoint}")
def get_credits(self):
"""
Retrieve the account's remaining credits.
:return: JSON response containing the number of credits left.
"""
return self.api_get("credits", stream=False)
def scrape_url(
self,
url: str,
params: Optional[RequestParamsDict] = None,
stream: bool = False,
content_type: str = "application/json",
):
"""
Scrape data from the specified URL.
:param url: The URL from which to scrape data.
:param params: Optional dictionary of additional parameters for the scrape request.
:return: JSON response containing the scraping results.
"""
# Add { "return_format": "markdown" } to the params if not already present
if "return_format" not in params:
params["return_format"] = "markdown"
# Set limit to 1
params["limit"] = 1
return self.api_post(
"crawl", {"url": url, **(params or {})}, stream, content_type
)
def crawl_url(
self,
url: str,
params: Optional[RequestParamsDict] = None,
stream: bool = False,
content_type: str = "application/json",
):
"""
Start crawling at the specified URL.
:param url: The URL to begin crawling.
:param params: Optional dictionary with additional parameters to customize the crawl.
:param stream: Boolean indicating if the response should be streamed. Defaults to False.
:return: JSON response or the raw response stream if streaming enabled.
"""
# Add { "return_format": "markdown" } to the params if not already present
if "return_format" not in params:
params["return_format"] = "markdown"
return self.api_post(
"crawl", {"url": url, **(params or {})}, stream, content_type
)
def links(
self,
url: str,
params: Optional[RequestParamsDict] = None,
stream: bool = False,
content_type: str = "application/json",
):
"""
Retrieve links from the specified URL.
:param url: The URL from which to extract links.
:param params: Optional parameters for the link retrieval request.
:return: JSON response containing the links.
"""
return self.api_post(
"links", {"url": url, **(params or {})}, stream, content_type
)
def extract_contacts(
self,
url: str,
params: Optional[RequestParamsDict] = None,
stream: bool = False,
content_type: str = "application/json",
):
"""
Extract contact information from the specified URL.
:param url: The URL from which to extract contact information.
:param params: Optional parameters for the contact extraction.
:return: JSON response containing extracted contact details.
"""
return self.api_post(
"pipeline/extract-contacts",
{"url": url, **(params or {})},
stream,
content_type,
)
def label(
self,
url: str,
params: Optional[RequestParamsDict] = None,
stream: bool = False,
content_type: str = "application/json",
):
"""
Apply labeling to data extracted from the specified URL.
:param url: The URL to label data from.
:param params: Optional parameters to guide the labeling process.
:return: JSON response with labeled data.
"""
return self.api_post(
"pipeline/label", {"url": url, **(params or {})}, stream, content_type
)
def _prepare_headers(self, content_type: str = "application/json"):
return {
"Content-Type": content_type,
"Authorization": f"Bearer {self.api_key}",
"User-Agent": "Spider-Client/0.0.27",
}
def _post_request(self, url: str, data, headers, stream=False):
return requests.post(url, headers=headers, json=data, stream=stream)
def _get_request(self, url: str, headers, stream=False):
return requests.get(url, headers=headers, stream=stream)
def _delete_request(self, url: str, headers, stream=False):
return requests.delete(url, headers=headers, stream=stream)
def _handle_error(self, response, action):
if response.status_code in [402, 409, 500]:
error_message = response.json().get("error", "Unknown error occurred")
raise Exception(
f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}"
)
else:
raise Exception(
f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}"
)
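
Note: a usage sketch for the Spider client defined above; the key is a placeholder and the call performs a real HTTP request against spider.cloud:

spider = Spider(api_key='YOUR_SPIDER_API_KEY')
pages = spider.scrape_url('https://spider.cloud', params={'return_format': 'markdown'})
for page in pages:
    print(page.get('url'), len(page.get('content') or ''))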

View File

@@ -0,0 +1,47 @@
from typing import Any, Union

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.spider.spiderApp import Spider
from core.tools.tool.builtin_tool import BuiltinTool


class ScrapeTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        # Initialize the Spider client with the stored API key.
        app = Spider(api_key=self.runtime.credentials['spider_api_key'])

        url = tool_parameters['url']
        mode = tool_parameters['mode']

        # Map tool parameters onto Spider request options; comma-separated
        # blacklist/whitelist strings are split into lists of patterns.
        options = {
            'limit': tool_parameters.get('limit', 0),
            'depth': tool_parameters.get('depth', 0),
            'blacklist': tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else [],
            'whitelist': tool_parameters.get('whitelist', '').split(',') if tool_parameters.get('whitelist') else [],
            'readability': tool_parameters.get('readability', False),
        }

        result = ""
        try:
            if mode == 'scrape':
                scrape_result = app.scrape_url(
                    url=url,
                    params=options,
                )

                for i in scrape_result:
                    result += "URL: " + i.get('url', '') + "\n"
                    result += "CONTENT: " + i.get('content', '') + "\n\n"
            elif mode == 'crawl':
                crawl_result = app.crawl_url(
                    url=url,
                    params=options,
                )

                for i in crawl_result:
                    result += "URL: " + i.get('url', '') + "\n"
                    result += "CONTENT: " + i.get('content', '') + "\n\n"
        except Exception as e:
            return self.create_text_message(f"An error occurred: {str(e)}")

        return self.create_text_message(result)
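
To make the parameter mapping concrete, a small runnable sketch (the input values are hypothetical; the splitting rule is the one used in _invoke above):

# How a comma-separated blacklist string becomes a list of patterns.
tool_parameters = {
    'url': 'https://example.com',
    'mode': 'crawl',
    'limit': 10,
    'blacklist': '/blog/*,/about',
}

blacklist = tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else []
print(blacklist)  # ['/blog/*', '/about']

# Keys that are absent fall back to defaults: depth=0, whitelist=[], readability=False.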


@@ -0,0 +1,102 @@
identity:
  name: scraper_crawler
  author: William Espegren
  label:
    en_US: Web Scraper & Crawler
    zh_Hans: 网页抓取与爬虫
description:
  human:
    en_US: A tool for scraping & crawling webpages. Input should be a URL.
    zh_Hans: 用于抓取和爬取网页的工具。输入应该是一个网址。
  llm: A tool for scraping & crawling webpages. Input should be a URL.
parameters:
  - name: url
    type: string
    required: true
    label:
      en_US: URL
      zh_Hans: 网址
    human_description:
      en_US: URL to be scraped or crawled
      zh_Hans: 要抓取或爬取的网址
    llm_description: URL to either be scraped or crawled
    form: llm
  - name: mode
    type: select
    required: true
    options:
      - value: scrape
        label:
          en_US: scrape
          zh_Hans: 抓取
      - value: crawl
        label:
          en_US: crawl
          zh_Hans: 爬取
    default: crawl
    label:
      en_US: Mode
      zh_Hans: 模式
    human_description:
      en_US: Choose whether to scrape a single page or crawl the entire website, following subpages.
      zh_Hans: 用于选择抓取网站或爬取整个网站及其子页面
    form: form
  - name: limit
    type: number
    required: false
    label:
      en_US: Maximum number of pages to crawl
      zh_Hans: 最大爬取页面数
    human_description:
      en_US: Specify the maximum number of pages to crawl per website. The crawler will stop after reaching this limit.
      zh_Hans: 指定每个网站要爬取的最大页面数。爬虫将在达到此限制后停止。
    form: form
    min: 0
    default: 0
  - name: depth
    type: number
    required: false
    label:
      en_US: Maximum depth of pages to crawl
      zh_Hans: 最大爬取深度
    human_description:
      en_US: The maximum link depth the crawler will follow.
      zh_Hans: 最大爬取深度的限制。
    form: form
    min: 0
    default: 0
  - name: blacklist
    type: string
    required: false
    label:
      en_US: URL patterns to exclude
      zh_Hans: 要排除的URL模式
    human_description:
      en_US: Blacklist a set of paths that you do not want to crawl. You can use regex patterns to help with the list.
      zh_Hans: 指定一组不想爬取的路径。您可以使用正则表达式模式来帮助定义列表。
    placeholder:
      en_US: /blog/*, /about
    form: form
  - name: whitelist
    type: string
    required: false
    label:
      en_US: URL patterns to include
      zh_Hans: 要包含的URL模式
    human_description:
      en_US: Whitelist a set of paths that you want to crawl, ignoring all other routes that do not match the patterns. You can use regex patterns to help with the list.
      zh_Hans: 指定一组要爬取的路径,忽略所有不匹配模式的其他路由。您可以使用正则表达式模式来帮助定义列表。
    placeholder:
      en_US: /blog/*, /about
    form: form
  - name: readability
    type: boolean
    required: false
    label:
      en_US: Pre-process the content for LLM usage
      zh_Hans: 仅返回页面的主要内容
    human_description:
      en_US: Use Mozilla's readability to pre-process the content for reading. This may drastically improve the content for LLM usage.
      zh_Hans: 如果启用,爬虫将仅返回页面的主要内容,不包括标题、导航、页脚等。
    form: form
    default: false
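
Taken together, the schema above admits tool inputs like the following (an illustrative Python dict, not part of the diff; the values are hypothetical):

# Example tool_parameters consistent with the scraper_crawler schema.
params = {
    "url": "https://example.com",   # required
    "mode": "crawl",                # "scrape" or "crawl"; defaults to "crawl"
    "limit": 20,                    # maximum pages per site; defaults to 0
    "depth": 3,                     # maximum crawl depth; defaults to 0
    "blacklist": "/blog/*,/about",  # comma-separated patterns, split by the tool
    "readability": True,            # pre-process content with readability
}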

File diff suppressed because one or more lines are too long

[Image file added: 23 KiB; preview omitted]


@@ -0,0 +1,21 @@
from typing import Any

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.tianditu.tools.poisearch import PoiSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class TiandituProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        # Validate the key with a cheap live POI search ("北京", i.e. Beijing,
        # within administrative region 156110000); any failure is surfaced
        # as a credential validation error.
        try:
            PoiSearchTool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).invoke(
                user_id='',
                tool_parameters={
                    'content': '北京',
                    'specify': '156110000',
                },
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))


@@ -0,0 +1,32 @@
identity:
  author: Listeng
  name: tianditu
  label:
    en_US: Tianditu
    zh_Hans: 天地图
    pt_BR: Tianditu
  description:
    en_US: The Tianditu tool provides place name search, geocoding, static map generation, and related services for the China region.
    zh_Hans: 天地图工具可以调用天地图的接口,实现中国区域内的地名搜索、地理编码、静态地图等功能。
    pt_BR: The Tianditu tool provides place name search, geocoding, static map generation, and related services for the China region.
  icon: icon.svg
  tags:
    - utilities
    - travel
credentials_for_provider:
  tianditu_api_key:
    type: secret-input
    required: true
    label:
      en_US: Tianditu API Key
      zh_Hans: 天地图Key
      pt_BR: Tianditu API key
    placeholder:
      en_US: Please input your Tianditu API key
      zh_Hans: 请输入你的天地图Key
      pt_BR: Please input your Tianditu API key
    help:
      en_US: Get your Tianditu API key from Tianditu
      zh_Hans: 获取您的天地图Key
      pt_BR: Get your Tianditu API key from Tianditu
    url: http://lbs.tianditu.gov.cn/home.html


@@ -0,0 +1,33 @@
import json
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GeocoderTool(BuiltinTool):
    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the Tianditu geocoder for the given keyword.
        """
        base_url = 'http://api.tianditu.gov.cn/geocoder'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        tk = self.runtime.credentials['tianditu_api_key']

        params = {
            'keyWord': keyword,
        }

        # Tianditu expects the search payload as a JSON string in the "ds" query parameter.
        result = requests.get(base_url + '?ds=' + json.dumps(params, ensure_ascii=False) + '&tk=' + tk).json()

        return self.create_json_message(result)
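
For reference, a standalone sketch of the HTTP call this tool performs outside the Dify runtime (the key is a placeholder, and the timeout and example response shape are assumptions, not part of the diff):

# Roughly the request GeocoderTool issues.
import json
import requests

tk = "your-tianditu-api-key"  # placeholder credential
ds = json.dumps({"keyWord": "北京"}, ensure_ascii=False)

resp = requests.get(
    "http://api.tianditu.gov.cn/geocoder",
    params={"ds": ds, "tk": tk},
    timeout=10,
)
print(resp.json())  # e.g. {"status": "0", "msg": "ok", "location": {...}}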

Some files were not shown because too many files have changed in this diff.