Compare commits
56 Commits
feat/updat... → fix/trace_
SHA1: c12596af48, 27e08a8e2e, 49ef9ef225, c013086e64, 48f872a68c, 4f9f175f25, 47e5dc218a, 90372932fe,
0bb2b285da, 3da854fe40, 57729823a0, 9e168f9d1c, ea45496a74, a5fcd91ba5, 2ba05b041f, 8e49146a35,
dad3fd2dc1, 284ef52bba, e493ce9981, 7b45a5d452, 4a026fa352, dc847ba145, c0ec40e483, 929c22a4e8,
ba181197c2, 218930c897, c8f5dfcf17, 27c8deb4ec, 4ae4895ebe, afe95fa780, 166a40c66e, 588615b20e,
d5dca46854, 23e5eeec00, 287b42997d, 5236cb1888, 3b5b548af3, 4782fb50c4, f55876bcc5, 8a80af39c9,
35f4a264d6, 6c798cbdaf, 279f1c986f, 443e96777b, 65bc4e0fc0, a6dbd26f75, f3f052ba36, 1bc90b992b,
fc37887a21, 984658f5e9, 4ed1476531, ca69e1a2f5, 20f73cb756, 4e2fba404d, 7943f7f697, 7c397f5722
@@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o

Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.

-Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and steps to troubleshoot.
+Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.

### 5. Visit dify in your browser
@@ -77,7 +77,7 @@ Dify depends on the following tools and libraries:

Dify is composed of a backend and a frontend. Navigate to the backend directory with `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory with `cd web/`, then follow the [Frontend README](web/README.md) to install it.

-Check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq) for a list of common issues and troubleshooting steps.
+Check the [installation FAQ](https://docs.dify.ai/v/zh-hans/learn-more/faq/install-faq) for a list of common issues and troubleshooting steps.

### 5. Visit Dify in your browser
@@ -82,7 +82,7 @@ Dify is composed of a backend and a frontend.

First, move to the backend directory with `cd api/` and install it by following the [Backend README](api/README.md).
Next, in a separate terminal, move to the frontend directory with `cd web/` and install it by following the [Frontend README](web/README.md).

-For common issues and troubleshooting steps, check the [installation FAQ](https://docs.dify.ai/getting-started/faq/install-faq).
+For common issues and troubleshooting steps, check the [installation FAQ](https://docs.dify.ai/v/japanese/learn-more/faq/install-faq).

### 5. Access dify in your browser
@@ -79,7 +79,7 @@ class HostedAzureOpenAiConfig(BaseSettings):
         default=False,
     )

-    HOSTED_OPENAI_API_KEY: Optional[str] = Field(
+    HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
         description='',
         default=None,
     )
@@ -1,4 +1,3 @@
-from typing import Optional

 from pydantic import BaseModel, Field, PositiveInt
@@ -8,32 +7,32 @@ class MyScaleConfig(BaseModel):
    MyScale configs
    """

-    MYSCALE_HOST: Optional[str] = Field(
+    MYSCALE_HOST: str = Field(
         description='MyScale host',
-        default=None,
+        default='localhost',
     )

-    MYSCALE_PORT: Optional[PositiveInt] = Field(
+    MYSCALE_PORT: PositiveInt = Field(
         description='MyScale port',
         default=8123,
     )

-    MYSCALE_USER: Optional[str] = Field(
+    MYSCALE_USER: str = Field(
         description='MyScale user',
-        default=None,
+        default='default',
     )

-    MYSCALE_PASSWORD: Optional[str] = Field(
+    MYSCALE_PASSWORD: str = Field(
         description='MyScale password',
-        default=None,
+        default='',
     )

-    MYSCALE_DATABASE: Optional[str] = Field(
+    MYSCALE_DATABASE: str = Field(
         description='MyScale database name',
-        default=None,
+        default='default',
     )

-    MYSCALE_FTS_PARAMS: Optional[str] = Field(
+    MYSCALE_FTS_PARAMS: str = Field(
         description='MyScale fts index parameters',
-        default=None,
+        default='',
     )
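With every field now required and concretely defaulted, a bare `MyScaleConfig()` resolves to usable connection settings instead of a cluster of `None`s. A minimal standalone sketch of the resulting behavior using plain pydantic (the stripped-down model below is illustrative, not the actual Dify config class):

```python
from pydantic import BaseModel, Field, PositiveInt

class MyScaleConfigSketch(BaseModel):
    # Required-typed fields with concrete defaults: never None.
    MYSCALE_HOST: str = Field(description='MyScale host', default='localhost')
    MYSCALE_PORT: PositiveInt = Field(description='MyScale port', default=8123)
    MYSCALE_USER: str = Field(description='MyScale user', default='default')

cfg = MyScaleConfigSketch()
assert cfg.MYSCALE_HOST == 'localhost'   # previously Optional[str] with default=None
```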
@@ -75,7 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
         )

         if last_id is not None:
-            last_segment = DocumentSegment.query.get(str(last_id))
+            last_segment = db.session.get(DocumentSegment, str(last_id))
             if last_segment:
                 query = query.filter(
                     DocumentSegment.position > last_segment.position)
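This swaps the legacy `Model.query.get()` pattern for `Session.get()`, the SQLAlchemy 2.0-style primary-key lookup; `Query.get()` has been deprecated since SQLAlchemy 1.4. A self-contained sketch of the new form against an in-memory SQLite database (the `Segment` model is a stand-in for `DocumentSegment`):

```python
from sqlalchemy import Integer, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Segment(Base):
    __tablename__ = 'segments'
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    text: Mapped[str] = mapped_column(String)

engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Segment(id=1, text='hello'))
    session.commit()
    # 2.0-style primary-key lookup; returns None if the row is absent,
    # replacing the legacy `Segment.query.get(1)` form.
    seg = session.get(Segment, 1)
    assert seg is not None and seg.text == 'hello'
```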
@@ -117,7 +117,7 @@ class MemberUpdateRoleApi(Resource):
         if not TenantAccountRole.is_valid_role(new_role):
             return {'code': 'invalid-role', 'message': 'Invalid role'}, 400

-        member = Account.query.get(str(member_id))
+        member = db.session.get(Account, str(member_id))
         if not member:
             abort(404)
@@ -1,6 +1,6 @@
 import logging

-from flask_restful import Resource, reqparse
+from flask_restful import Resource, fields, marshal_with, reqparse
 from werkzeug.exceptions import InternalServerError

 from controllers.service_api import api
@@ -21,14 +21,43 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from core.model_runtime.errors.invoke import InvokeError
+from extensions.ext_database import db
 from libs import helper
 from models.model import App, AppMode, EndUser
+from models.workflow import WorkflowRun
 from services.app_generate_service import AppGenerateService

 logger = logging.getLogger(__name__)


 class WorkflowRunApi(Resource):
+    workflow_run_fields = {
+        'id': fields.String,
+        'workflow_id': fields.String,
+        'status': fields.String,
+        'inputs': fields.Raw,
+        'outputs': fields.Raw,
+        'error': fields.String,
+        'total_steps': fields.Integer,
+        'total_tokens': fields.Integer,
+        'created_at': fields.DateTime,
+        'finished_at': fields.DateTime,
+        'elapsed_time': fields.Float,
+    }
+
+    @validate_app_token
+    @marshal_with(workflow_run_fields)
+    def get(self, app_model: App, workflow_id: str):
+        """
+        Get a workflow task running detail
+        """
+        app_mode = AppMode.value_of(app_model.mode)
+        if app_mode != AppMode.WORKFLOW:
+            raise NotWorkflowAppError()
+
+        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first()
+        return workflow_run
+
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser):
         """
@@ -88,5 +117,5 @@ class WorkflowTaskStopApi(Resource):
     }


-api.add_resource(WorkflowRunApi, '/workflows/run')
+api.add_resource(WorkflowRunApi, '/workflows/run/<string:workflow_id>', '/workflows/run')
 api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop')
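Because the resource is registered under both `/workflows/run/<string:workflow_id>` and `/workflows/run`, existing POST clients keep working while the new GET endpoint exposes run details. Note that the path parameter is named `workflow_id` but is matched against `WorkflowRun.id`, so callers pass the run's id. A hedged sketch of the call (base URL, token, and id are placeholders):

```python
import requests

API_BASE = 'http://localhost:5001/v1'   # placeholder service-api base URL
APP_TOKEN = 'app-xxxxxxxx'              # placeholder app API token
workflow_run_id = '0fa7...'             # placeholder WorkflowRun.id

resp = requests.get(
    f'{API_BASE}/workflows/run/{workflow_run_id}',
    headers={'Authorization': f'Bearer {APP_TOKEN}'},
    timeout=10,
)
resp.raise_for_status()
run = resp.json()
# Fields marshalled by workflow_run_fields above:
print(run['status'], run['total_tokens'], run['elapsed_time'])
```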
@@ -64,6 +64,7 @@ User Input:
 SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "Please help me predict the three most likely questions that human would ask, "
     "and keeping each question under 20 characters.\n"
+    "MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n"
     "The output must be an array in JSON format following the specified schema:\n"
     "[\"question1\",\"question2\",\"question3\"]\n"
 )
@@ -103,7 +103,7 @@ class TokenBufferMemory:

         if curr_message_tokens > max_token_limit:
             pruned_memory = []
-            while curr_message_tokens > max_token_limit and prompt_messages:
+            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
                 pruned_memory.append(prompt_messages.pop(0))
                 curr_message_tokens = self.model_instance.get_llm_num_tokens(
                     prompt_messages
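The changed guard matters when a single message alone exceeds the token budget: the old `and prompt_messages` condition could drain the list to empty, while `len(prompt_messages) > 1` always leaves the newest message in place. A self-contained sketch of the pruning logic with a stubbed token counter (the whitespace counter stands in for the real `get_llm_num_tokens`):

```python
def count_tokens(messages: list[str]) -> int:
    # Stand-in for model_instance.get_llm_num_tokens().
    return sum(len(m.split()) for m in messages)

def prune(messages: list[str], max_token_limit: int) -> list[str]:
    pruned = []
    # Drop oldest messages first, but never drop the last remaining one.
    while count_tokens(messages) > max_token_limit and len(messages) > 1:
        pruned.append(messages.pop(0))
    return pruned

history = ['a long old message ' * 50, 'short recent message']
prune(history, max_token_limit=10)
assert history == ['short recent message']   # latest turn survives even over budget
```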
@@ -27,9 +27,9 @@ parameter_rules:
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 4096
+    default: 8192
     min: 1
-    max: 4096
+    max: 8192
   - name: response_format
     use_template: response_format
 pricing:
@@ -113,6 +113,11 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         if system:
             extra_model_kwargs['system'] = system

+        # Add the new header for claude-3-5-sonnet-20240620 model
+        extra_headers = {}
+        if model == "claude-3-5-sonnet-20240620":
+            extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
+
         if tools:
             extra_model_kwargs['tools'] = [
                 self._transform_tool_prompt(tool) for tool in tools
@@ -121,6 +126,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 model=model,
                 messages=prompt_message_dicts,
                 stream=stream,
+                extra_headers=extra_headers,
                 **model_parameters,
                 **extra_model_kwargs
             )
@@ -130,6 +136,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 model=model,
                 messages=prompt_message_dicts,
                 stream=stream,
+                extra_headers=extra_headers,
                 **model_parameters,
                 **extra_model_kwargs
             )
@@ -138,7 +145,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)

         return self._handle_chat_generate_response(model, credentials, response, prompt_messages)

     def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                  model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                  stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
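The `anthropic-beta: max-tokens-3-5-sonnet-2024-07-15` header is what backs the 4096 → 8192 `max_tokens` change in the YAML above: without it, the API caps claude-3-5-sonnet-20240620 output at 4096 tokens. A minimal sketch with the official `anthropic` SDK (the API key and prompt are placeholders):

```python
import anthropic

client = anthropic.Anthropic(api_key='sk-ant-...')  # placeholder key

model = 'claude-3-5-sonnet-20240620'
extra_headers = {}
if model == 'claude-3-5-sonnet-20240620':
    # Beta header that raises the output ceiling from 4096 to 8192 tokens.
    extra_headers['anthropic-beta'] = 'max-tokens-3-5-sonnet-2024-07-15'

message = client.messages.create(
    model=model,
    max_tokens=8192,
    messages=[{'role': 'user', 'content': 'Summarize the plot of Hamlet.'}],
    extra_headers=extra_headers,
)
print(message.content[0].text)
```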
@@ -1,6 +1,8 @@
 - gpt-4
 - gpt-4o
 - gpt-4o-2024-05-13
+- gpt-4o-mini
+- gpt-4o-mini-2024-07-18
 - gpt-4-turbo
 - gpt-4-turbo-2024-04-09
 - gpt-4-turbo-preview
@@ -0,0 +1,44 @@
model: gpt-4o-mini-2024-07-18
label:
  zh_Hans: gpt-4o-mini-2024-07-18
  en_US: gpt-4o-mini-2024-07-18
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 16384
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.15'
  output: '0.60'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,44 @@
model: gpt-4o-mini
label:
  zh_Hans: gpt-4o-mini
  en_US: gpt-4o-mini
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 16384
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.15'
  output: '0.60'
  unit: '0.000001'
  currency: USD
@@ -1,4 +1,5 @@
 - openai/gpt-4o
+- openai/gpt-4o-mini
 - openai/gpt-4
 - openai/gpt-4-32k
 - openai/gpt-3.5-turbo
@@ -0,0 +1,43 @@
model: openai/gpt-4o-mini
label:
  en_US: gpt-4o-mini
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
  - vision
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 16384
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: "0.15"
  output: "0.60"
  unit: "0.000001"
  currency: USD
[two icon images added: 9.2 KiB and 9.5 KiB]
api/core/model_runtime/model_providers/sagemaker/llm/llm.py (new file, 238 lines)
@@ -0,0 +1,238 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union

import boto3

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    I18nObject,
    ModelFeature,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
    ParameterType,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)


class SageMakerLargeLanguageModel(LargeLanguageModel):
    """
    Model class for SageMaker large language model.
    """
    sagemaker_client: Any = None

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # get model mode
        model_mode = self.get_model_mode(model, credentials)

        if not self.sagemaker_client:
            access_key = credentials.get('access_key')
            secret_key = credentials.get('secret_key')
            aws_region = credentials.get('aws_region')
            if aws_region:
                if access_key and secret_key:
                    self.sagemaker_client = boto3.client("sagemaker-runtime",
                                                         aws_access_key_id=access_key,
                                                         aws_secret_access_key=secret_key,
                                                         region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
            else:
                self.sagemaker_client = boto3.client("sagemaker-runtime")

        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=sagemaker_endpoint,
            Body=json.dumps(
                {
                    "inputs": prompt_messages[0].content,
                    "parameters": {"stop": stop},
                    "history": []
                }
            ),
            ContentType="application/json",
        )

        assistant_text = response_model['Body'].read().decode('utf8')

        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=assistant_text
        )

        usage = self._calc_response_usage(model, credentials, 0, 0)

        response = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage
        )

        return response

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        # get model mode
        model_mode = self.get_model_mode(model)

        try:
            return 0
        except Exception as e:
            raise self._transform_invoke_error(e)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # get model mode
            model_mode = self.get_model_mode(model)
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        rules = [
            ParameterRule(
                name='temperature',
                type=ParameterType.FLOAT,
                use_template='temperature',
                label=I18nObject(
                    zh_Hans='温度',
                    en_US='Temperature'
                ),
            ),
            ParameterRule(
                name='top_p',
                type=ParameterType.FLOAT,
                use_template='top_p',
                label=I18nObject(
                    zh_Hans='Top P',
                    en_US='Top P'
                )
            ),
            ParameterRule(
                name='max_tokens',
                type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                max=credentials.get('context_length', 2048),
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度',
                    en_US='Max Tokens'
                )
            )
        ]

        completion_type = LLMMode.value_of(credentials["mode"])

        if completion_type == LLMMode.CHAT:
            print(f"completion_type : {LLMMode.CHAT.value}")

        if completion_type == LLMMode.COMPLETION:
            print(f"completion_type : {LLMMode.COMPLETION.value}")

        features = []

        support_function_call = credentials.get('support_function_call', False)
        if support_function_call:
            features.append(ModelFeature.TOOL_CALL)

        support_vision = credentials.get('support_vision', False)
        if support_vision:
            features.append(ModelFeature.VISION)

        context_length = credentials.get('context_length', 2048)

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            features=features,
            model_properties={
                ModelPropertyKey.MODE: completion_type,
                ModelPropertyKey.CONTEXT_SIZE: context_length
            },
            parameter_rules=rules
        )

        return entity
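The contract `_invoke` assumes is simple: a JSON body with `inputs`, `parameters`, and `history`, answered with the assistant text as a plain UTF-8 string. A standalone sketch of one call against such an endpoint (the endpoint name and region are placeholders; the payload mirrors the code above):

```python
import json

import boto3

client = boto3.client('sagemaker-runtime', region_name='us-east-1')  # placeholder region

response = client.invoke_endpoint(
    EndpointName='my-llm-endpoint',  # placeholder endpoint name
    Body=json.dumps({
        'inputs': 'What is the capital of France?',
        'parameters': {'stop': []},
        'history': [],
    }),
    ContentType='application/json',
)
# The endpoint is expected to return the assistant text directly.
print(response['Body'].read().decode('utf8'))
```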
@@ -0,0 +1,190 @@
import json
import logging
from typing import Any, Optional

import boto3

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel

logger = logging.getLogger(__name__)


class SageMakerRerankModel(RerankModel):
    """
    Model class for SageMaker rerank model.
    """
    sagemaker_client: Any = None

    def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
        inputs = [query_input] * len(docs)
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=rerank_endpoint,
            Body=json.dumps(
                {
                    "inputs": inputs,
                    "docs": docs
                }
            ),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        scores = json_obj['scores']
        return scores if isinstance(scores, list) else [scores]

    def _invoke(self, model: str, credentials: dict,
                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                user: Optional[str] = None) \
            -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        line = 0
        try:
            if len(docs) == 0:
                return RerankResult(
                    model=model,
                    docs=docs
                )

            line = 1
            if not self.sagemaker_client:
                access_key = credentials.get('aws_access_key_id')
                secret_key = credentials.get('aws_secret_access_key')
                aws_region = credentials.get('aws_region')
                if aws_region:
                    if access_key and secret_key:
                        self.sagemaker_client = boto3.client("sagemaker-runtime",
                                                             aws_access_key_id=access_key,
                                                             aws_secret_access_key=secret_key,
                                                             region_name=aws_region)
                    else:
                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime")

            line = 2

            sagemaker_endpoint = credentials.get('sagemaker_endpoint')
            candidate_docs = []

            scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint)
            for idx in range(len(scores)):
                candidate_docs.append({"content": docs[idx], "score": scores[idx]})

            # sorted() returns a new list, so the result must be assigned
            candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True)

            line = 3
            rerank_documents = []
            for idx, result in enumerate(candidate_docs):
                rerank_document = RerankDocument(
                    index=idx,
                    text=result.get('content'),
                    score=result.get('score', -100.0)
                )

                if score_threshold is not None:
                    if rerank_document.score >= score_threshold:
                        rerank_documents.append(rerank_document)
                else:
                    rerank_documents.append(rerank_document)

            return RerankResult(
                model=model,
                docs=rerank_documents
            )

        except Exception as e:
            logger.exception(f'Exception {e}, line : {line}')

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={},
            parameter_rules=[]
        )

        return entity
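One subtlety in the ranking step: Python's `sorted()` returns a new list rather than sorting in place, so its result must be captured (or `list.sort()` used instead). A compact, SageMaker-independent sketch of the rank-then-threshold flow:

```python
def rank_and_filter(docs: list[str], scores: list[float],
                    score_threshold: float | None = None) -> list[dict]:
    # Pair each doc with its score, highest score first. sorted() builds
    # a new list, so the result is assigned rather than discarded.
    candidates = sorted(
        ({'content': d, 'score': s} for d, s in zip(docs, scores)),
        key=lambda x: x['score'], reverse=True,
    )
    if score_threshold is None:
        return candidates
    return [c for c in candidates if c['score'] >= score_threshold]

ranked = rank_and_filter(['a', 'b', 'c'], [0.2, 0.9, 0.5], score_threshold=0.4)
assert [c['content'] for c in ranked] == ['b', 'c']
```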
@@ -0,0 +1,17 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class SageMakerProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        pass
api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml (new file, 125 lines)
@@ -0,0 +1,125 @@
provider: sagemaker
label:
  zh_Hans: Sagemaker
  en_US: Sagemaker
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
description:
  en_US: Customized model on Sagemaker
  zh_Hans: Sagemaker上的私有化部署的模型
background: "#ECE9E3"
help:
  title:
    en_US: How to deploy customized model on Sagemaker
    zh_Hans: 如何在Sagemaker上的私有化部署的模型
  url:
    en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint
    zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9
supported_model_types:
  - llm
  - text-embedding
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Completion mode
      type: select
      required: false
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: sagemaker_endpoint
      label:
        en_US: sagemaker endpoint
      type: text-input
      required: true
      placeholder:
        zh_Hans: 请输出你的Sagemaker推理端点
        en_US: Enter your Sagemaker Inference endpoint
    - variable: aws_access_key_id
      required: false
      label:
        en_US: Access Key (If not provided, credentials are obtained from the running environment.)
        zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。)
      type: secret-input
      placeholder:
        en_US: Enter your Access Key
        zh_Hans: 在此输入您的 Access Key
    - variable: aws_secret_access_key
      required: false
      label:
        en_US: Secret Access Key
        zh_Hans: Secret Access Key
      type: secret-input
      placeholder:
        en_US: Enter your Secret Access Key
        zh_Hans: 在此输入您的 Secret Access Key
    - variable: aws_region
      required: false
      label:
        en_US: AWS Region
        zh_Hans: AWS 地区
      type: select
      default: us-east-1
      options:
        - value: us-east-1
          label:
            en_US: US East (N. Virginia)
            zh_Hans: 美国东部 (弗吉尼亚北部)
        - value: us-west-2
          label:
            en_US: US West (Oregon)
            zh_Hans: 美国西部 (俄勒冈州)
        - value: ap-southeast-1
          label:
            en_US: Asia Pacific (Singapore)
            zh_Hans: 亚太地区 (新加坡)
        - value: ap-northeast-1
          label:
            en_US: Asia Pacific (Tokyo)
            zh_Hans: 亚太地区 (东京)
        - value: eu-central-1
          label:
            en_US: Europe (Frankfurt)
            zh_Hans: 欧洲 (法兰克福)
        - value: us-gov-west-1
          label:
            en_US: AWS GovCloud (US-West)
            zh_Hans: AWS GovCloud (US-West)
        - value: ap-southeast-2
          label:
            en_US: Asia Pacific (Sydney)
            zh_Hans: 亚太地区 (悉尼)
        - value: cn-north-1
          label:
            en_US: AWS Beijing (cn-north-1)
            zh_Hans: 中国北京 (cn-north-1)
        - value: cn-northwest-1
          label:
            en_US: AWS Ningxia (cn-northwest-1)
            zh_Hans: 中国宁夏 (cn-northwest-1)
@@ -0,0 +1,214 @@
import itertools
import json
import logging
import time
from typing import Any, Optional

import boto3

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

BATCH_SIZE = 20
CONTEXT_SIZE = 8192

logger = logging.getLogger(__name__)


def batch_generator(generator, batch_size):
    while True:
        batch = list(itertools.islice(generator, batch_size))
        if not batch:
            break
        yield batch


class SageMakerEmbeddingModel(TextEmbeddingModel):
    """
    Model class for SageMaker text embedding model.
    """
    sagemaker_client: Any = None

    def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]):
        response_model = sm_client.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=json.dumps(
                {
                    "inputs": content_list,
                    "parameters": {},
                    "is_query": False,
                    "instruction": ''
                }
            ),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        embeddings = json_obj['embeddings']
        return embeddings

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        # get model properties
        try:
            line = 1
            if not self.sagemaker_client:
                access_key = credentials.get('aws_access_key_id')
                secret_key = credentials.get('aws_secret_access_key')
                aws_region = credentials.get('aws_region')
                if aws_region:
                    if access_key and secret_key:
                        self.sagemaker_client = boto3.client("sagemaker-runtime",
                                                             aws_access_key_id=access_key,
                                                             aws_secret_access_key=secret_key,
                                                             region_name=aws_region)
                    else:
                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime")

            line = 2
            sagemaker_endpoint = credentials.get('sagemaker_endpoint')

            line = 3
            truncated_texts = [item[:CONTEXT_SIZE] for item in texts]

            batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE)
            all_embeddings = []

            line = 4
            for batch in batches:
                embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch)
                all_embeddings.extend(embeddings)

            line = 5
            # calc usage
            usage = self._calc_response_usage(
                model=model,
                credentials=credentials,
                tokens=0  # it's not a SaaS API, so usage is meaningless
            )
            line = 6

            return TextEmbeddingResult(
                embeddings=all_embeddings,
                usage=usage,
                model=model
            )

        except Exception as e:
            logger.exception(f'Exception {e}, line : {line}')

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return:
        """
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            print("validate_credentials ok....")
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                KeyError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TEXT_EMBEDDING,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE,
                ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE,
            },
            parameter_rules=[]
        )

        return entity
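`batch_generator` is a small generic utility: it drains any iterator in fixed-size chunks via `itertools.islice`, which is what caps each embedding request at `BATCH_SIZE` texts. A usage sketch:

```python
import itertools

def batch_generator(generator, batch_size):
    # Consume up to batch_size items at a time until the iterator is empty.
    while True:
        batch = list(itertools.islice(generator, batch_size))
        if not batch:
            break
        yield batch

texts = [f'text-{i}' for i in range(7)]
batches = list(batch_generator(iter(texts), batch_size=3))
assert batches == [['text-0', 'text-1', 'text-2'],
                   ['text-3', 'text-4', 'text-5'],
                   ['text-6']]
```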
[two icon images added: 9.0 KiB and 1.9 KiB]
@@ -0,0 +1,6 @@
- step-1-8k
- step-1-32k
- step-1-128k
- step-1-256k
- step-1v-8k
- step-1v-32k
api/core/model_runtime/model_providers/stepfun/llm/llm.py (new file, 328 lines)
@@ -0,0 +1,328 @@
import json
from collections.abc import Generator
from typing import Optional, Union, cast

import requests

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelFeature,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
    ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        self._add_function_call(model, credentials)
        user = user[:32] if user else None
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        return AIModelEntity(
            model=model,
            label=I18nObject(en_US=model, zh_Hans=model),
            model_type=ModelType.LLM,
            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
                if credentials.get('function_calling_type') == 'tool_call'
                else [],
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)),
                ModelPropertyKey.MODE: LLMMode.CHAT.value,
            },
            parameter_rules=[
                ParameterRule(
                    name='temperature',
                    use_template='temperature',
                    label=I18nObject(en_US='Temperature', zh_Hans='温度'),
                    type=ParameterType.FLOAT,
                ),
                ParameterRule(
                    name='max_tokens',
                    use_template='max_tokens',
                    default=512,
                    min=1,
                    max=int(credentials.get('max_tokens', 1024)),
                    label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
                    type=ParameterType.INT,
                ),
                ParameterRule(
                    name='top_p',
                    use_template='top_p',
                    label=I18nObject(en_US='Top P', zh_Hans='Top P'),
                    type=ParameterType.FLOAT,
                ),
            ]
        )

    def _add_custom_parameters(self, credentials: dict) -> None:
        credentials['mode'] = 'chat'
        credentials['endpoint_url'] = 'https://api.stepfun.com/v1'

    def _add_function_call(self, model: str, credentials: dict) -> None:
        model_schema = self.get_model_schema(model, credentials)
        if model_schema and {
            ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
        }.intersection(model_schema.features or []):
            credentials['function_calling_type'] = 'tool_call'

    def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
        """
        Convert PromptMessage to dict for OpenAI API format
        """
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                sub_messages = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(PromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "text",
                            "text": message_content.data
                        }
                        sub_messages.append(sub_message_dict)
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        sub_message_dict = {
                            "type": "image_url",
                            "image_url": {
                                "url": message_content.data,
                            }
                        }
                        sub_messages.append(sub_message_dict)
                message_dict = {"role": "user", "content": sub_messages}
        elif isinstance(message, AssistantPromptMessage):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
                message_dict["tool_calls"] = []
                for function_call in message.tool_calls:
                    message_dict["tool_calls"].append({
                        "id": function_call.id,
                        "type": function_call.type,
                        "function": {
                            "name": function_call.function.name,
                            "arguments": function_call.function.arguments
                        }
                    })
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")

        if message.name:
            message_dict["name"] = message.name

        return message_dict

    def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
        """
        Extract tool calls from response

        :param response_tool_calls: response tool calls
        :return: list of tool calls
        """
        tool_calls = []
        if response_tool_calls:
            for response_tool_call in response_tool_calls:
                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
                    arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
                )

                tool_call = AssistantPromptMessage.ToolCall(
                    id=response_tool_call["id"] if response_tool_call.get("id") else "",
                    type=response_tool_call["type"] if response_tool_call.get("type") else "",
                    function=function
                )
                tool_calls.append(tool_call)

        return tool_calls

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response

        :param model: model name
        :param credentials: model credentials
        :param response: streamed response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator
        """
        full_assistant_content = ''
        chunk_index = 0

        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
                -> LLMResultChunk:
            # calculate num tokens
            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)

            # transform usage
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

            return LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=index,
                    message=message,
                    finish_reason=finish_reason,
                    usage=usage
                )
            )

        tools_calls: list[AssistantPromptMessage.ToolCall] = []
        finish_reason = "Unknown"

        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
            def get_tool_call(tool_name: str):
                if not tool_name:
                    return tools_calls[-1]

                tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
                if tool_call is None:
                    tool_call = AssistantPromptMessage.ToolCall(
                        id='',
                        type='',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
                    )
                    tools_calls.append(tool_call)

                return tool_call

            for new_tool_call in new_tool_calls:
                # get tool call
                tool_call = get_tool_call(new_tool_call.function.name)
                # update tool call
                if new_tool_call.id:
                    tool_call.id = new_tool_call.id
                if new_tool_call.type:
                    tool_call.type = new_tool_call.type
                if new_tool_call.function.name:
                    tool_call.function.name = new_tool_call.function.name
                if new_tool_call.function.arguments:
                    tool_call.function.arguments += new_tool_call.function.arguments

        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
            if chunk:
                # ignore sse comments
                if chunk.startswith(':'):
                    continue
                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
                chunk_json = None
                try:
                    chunk_json = json.loads(decoded_chunk)
                # stream ended
                except json.JSONDecodeError as e:
                    yield create_final_llm_result_chunk(
                        index=chunk_index + 1,
                        message=AssistantPromptMessage(content=""),
                        finish_reason="Non-JSON encountered."
                    )
                    break
                if not chunk_json or len(chunk_json['choices']) == 0:
                    continue

                choice = chunk_json['choices'][0]
                finish_reason = chunk_json['choices'][0].get('finish_reason')
                chunk_index += 1

                if 'delta' in choice:
                    delta = choice['delta']
                    delta_content = delta.get('content')

                    assistant_message_tool_calls = delta.get('tool_calls', None)
                    # assistant_message_function_call = delta.delta.function_call

                    # extract tool calls from response
                    if assistant_message_tool_calls:
                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
                        increase_tool_call(tool_calls)

                    if delta_content is None or delta_content == '':
                        continue

                    # transform assistant message to prompt message
                    assistant_prompt_message = AssistantPromptMessage(
                        content=delta_content,
                        tool_calls=tool_calls if assistant_message_tool_calls else []
                    )

                    full_assistant_content += delta_content
                elif 'text' in choice:
                    choice_text = choice.get('text', '')
                    if choice_text == '':
                        continue

                    # transform assistant message to prompt message
                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
                    full_assistant_content += choice_text
                else:
                    continue

                # check payload indicator for completion
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                    )
                )

            chunk_index += 1

        if tools_calls:
            yield LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=LLMResultChunkDelta(
                    index=chunk_index,
                    message=AssistantPromptMessage(
                        tool_calls=tools_calls,
                        content=""
                    ),
                )
            )

        yield create_final_llm_result_chunk(
            index=chunk_index,
            message=AssistantPromptMessage(content=""),
            finish_reason=finish_reason
        )
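`increase_tool_call` handles the fact that OpenAI-compatible streams emit a tool call in fragments: an early delta carries the id, type, and function name, and later deltas append argument text piecewise. A self-contained sketch of the same merge idea on plain dicts (field names follow the OpenAI chunk format; this is an illustration, not the class's actual code):

```python
def merge_tool_call_deltas(deltas: list[dict]) -> dict:
    """Accumulate streamed tool-call fragments into one complete call."""
    call = {'id': '', 'type': '', 'name': '', 'arguments': ''}
    for delta in deltas:
        if delta.get('id'):
            call['id'] = delta['id']
        if delta.get('type'):
            call['type'] = delta['type']
        fn = delta.get('function', {})
        if fn.get('name'):
            call['name'] = fn['name']
        if fn.get('arguments'):
            call['arguments'] += fn['arguments']   # arguments arrive piecewise
    return call

deltas = [
    {'id': 'call_1', 'type': 'function', 'function': {'name': 'get_weather', 'arguments': ''}},
    {'function': {'arguments': '{"city": "Par'}},
    {'function': {'arguments': 'is"}'}},
]
merged = merge_tool_call_deltas(deltas)
assert merged['arguments'] == '{"city": "Paris"}'
```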
@@ -0,0 +1,25 @@
model: step-1-128k
label:
  zh_Hans: step-1-128k
  en_US: step-1-128k
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 128000
pricing:
  input: '0.04'
  output: '0.20'
  unit: '0.001'
  currency: RMB
@@ -0,0 +1,25 @@
model: step-1-256k
label:
  zh_Hans: step-1-256k
  en_US: step-1-256k
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 256000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 256000
pricing:
  input: '0.095'
  output: '0.300'
  unit: '0.001'
  currency: RMB
@@ -0,0 +1,28 @@
model: step-1-32k
label:
  zh_Hans: step-1-32k
  en_US: step-1-32k
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
pricing:
  input: '0.015'
  output: '0.070'
  unit: '0.001'
  currency: RMB
@@ -0,0 +1,28 @@
model: step-1-8k
label:
  zh_Hans: step-1-8k
  en_US: step-1-8k
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 8000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 8000
pricing:
  input: '0.005'
  output: '0.020'
  unit: '0.001'
  currency: RMB
@@ -0,0 +1,25 @@
model: step-1v-32k
label:
  zh_Hans: step-1v-32k
  en_US: step-1v-32k
model_type: llm
features:
  - vision
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
pricing:
  input: '0.015'
  output: '0.070'
  unit: '0.001'
  currency: RMB
@@ -0,0 +1,25 @@
model: step-1v-8k
label:
  zh_Hans: step-1v-8k
  en_US: step-1v-8k
model_type: llm
features:
  - vision
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 8192
pricing:
  input: '0.005'
  output: '0.020'
  unit: '0.001'
  currency: RMB
api/core/model_runtime/model_providers/stepfun/stepfun.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class StepfunProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(
                model='step-1-8k',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
api/core/model_runtime/model_providers/stepfun/stepfun.yaml (new file, 81 lines)
@@ -0,0 +1,81 @@
provider: stepfun
label:
  zh_Hans: 阶跃星辰
  en_US: Stepfun
description:
  en_US: Models provided by stepfun, such as step-1-8k, step-1-32k, step-1v-8k, step-1v-32k, step-1-128k and step-1-256k
  zh_Hans: 阶跃星辰提供的模型,例如 step-1-8k、step-1-32k、step-1v-8k、step-1v-32k、step-1-128k 和 step-1-256k。
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
background: "#FFFFFF"
help:
  title:
    en_US: Get your API Key from stepfun
    zh_Hans: 从 stepfun 获取 API Key
  url:
    en_US: https://platform.stepfun.com/interface-key
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
  - customizable-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      type: text-input
      default: '8192'
      placeholder:
        zh_Hans: 在此输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: max_tokens
      label:
        zh_Hans: 最大 token 上限
        en_US: Upper bound for max tokens
      default: '8192'
      type: text-input
    - variable: function_calling_type
      label:
        en_US: Function calling
      type: select
      required: false
      default: no_call
      options:
        - value: no_call
          label:
            en_US: Not supported
            zh_Hans: 不支持
        - value: tool_call
          label:
            en_US: Tool Call
            zh_Hans: Tool Call
@@ -35,3 +35,4 @@ parameter_rules:
        zh_Hans: 禁用模型自行进行外部搜索。
        en_US: Disable the model to perform external search.
      required: false
  deprecated: true
@@ -1,4 +1,4 @@
model: ernie-4.0-8k-Latest
model: ernie-4.0-8k-latest
label:
  en_US: Ernie-4.0-8K-Latest
model_type: llm
@@ -0,0 +1,40 @@
model: ernie-4.0-turbo-8k
label:
  en_US: Ernie-4.0-turbo-8K
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0.1
    max: 1.0
    default: 0.8
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 2
    max: 2048
  - name: presence_penalty
    use_template: presence_penalty
    default: 1.0
    min: 1.0
    max: 2.0
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: response_format
    use_template: response_format
  - name: disable_search
    label:
      zh_Hans: 禁用搜索
      en_US: Disable Search
    type: boolean
    help:
      zh_Hans: 禁用模型自行进行外部搜索。
      en_US: Disable the model to perform external search.
    required: false
@@ -28,3 +28,4 @@ parameter_rules:
    default: 1.0
    min: 1.0
    max: 2.0
deprecated: true
@@ -0,0 +1,30 @@
model: ernie-character-8k-0321
label:
  en_US: ERNIE-Character-8K
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0.1
    max: 1.0
    default: 0.95
  - name: top_p
    use_template: top_p
    min: 0
    max: 1.0
    default: 0.7
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 2
    max: 1024
  - name: presence_penalty
    use_template: presence_penalty
    default: 1.0
    min: 1.0
    max: 2.0
@@ -28,3 +28,4 @@ parameter_rules:
    default: 1.0
    min: 1.0
    max: 2.0
deprecated: true
@@ -28,3 +28,4 @@ parameter_rules:
    default: 1.0
    min: 1.0
    max: 2.0
deprecated: true
@@ -97,6 +97,7 @@ class BaiduAccessToken:
        baidu_access_tokens_lock.release()
        return token


class ErnieMessage:
    class Role(Enum):
        USER = 'user'
@@ -137,7 +138,9 @@ class ErnieBotModel:
        'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
        'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
        'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
        'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
        'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
    }

@@ -149,7 +152,8 @@ class ErnieBotModel:
        'ernie-3.5-8k-1222',
        'ernie-3.5-4k-0205',
        'ernie-3.5-128k',
        'ernie-4.0-8k'
        'ernie-4.0-8k',
        'ernie-4.0-turbo-8k',
        'ernie-4.0-turbo-8k-preview'
    ]
@@ -453,9 +453,11 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
        if credentials['server_url'].endswith('/'):
            credentials['server_url'] = credentials['server_url'][:-1]

        api_key = credentials.get('api_key') or "abc"

        client = OpenAI(
            base_url=f'{credentials["server_url"]}/v1',
            api_key='abc',
            api_key=api_key,
            max_retries=3,
            timeout=60,
        )
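The api_key fallback above exists because the OpenAI SDK refuses to construct a client without some api_key, while self-hosted Xinference servers often run without authentication; the literal "abc" is a harmless placeholder in that case. An illustrative restatement of the pattern, assuming the openai package:

    from openai import OpenAI

    def build_xinference_client(credentials: dict) -> OpenAI:
        server_url = credentials['server_url'].rstrip('/')
        # Use the real key when configured, otherwise a dummy token the
        # unauthenticated server will simply ignore.
        api_key = credentials.get('api_key') or 'abc'
        return OpenAI(base_url=f'{server_url}/v1', api_key=api_key, max_retries=3, timeout=60)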
@@ -44,15 +44,23 @@ class XinferenceRerankModel(RerankModel):
                docs=[]
            )

        if credentials['server_url'].endswith('/'):
            credentials['server_url'] = credentials['server_url'][:-1]
        server_url = credentials['server_url']
        model_uid = credentials['model_uid']
        api_key = credentials.get('api_key')
        if server_url.endswith('/'):
            server_url = server_url[:-1]
        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}

        try:
            handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers)
            response = handle.rerank(
                documents=docs,
                query=query,
                top_n=top_n,
            )
        except RuntimeError as e:
            raise InvokeServerUnavailableError(str(e))

        handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
        response = handle.rerank(
            documents=docs,
            query=query,
            top_n=top_n,
        )

        rerank_documents = []
        for idx, result in enumerate(response['results']):
@@ -102,7 +110,7 @@ class XinferenceRerankModel(RerankModel):
        if not isinstance(xinference_client, RESTfulRerankModelHandle):
            raise InvokeBadRequestError(
                'please check model type, the model you want to invoke is not a rerank model')

        self.invoke(
            model=model,
            credentials=credentials,
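Note how the auth_headers expression yields an empty mapping, not an Authorization header with an empty token, when no key is configured; the same conditional is reused in the speech-to-text and embedding handles below. An illustrative check:

    def bearer_headers(api_key: str | None) -> dict[str, str]:
        # No header at all when auth is disabled on the Xinference server
        return {'Authorization': f'Bearer {api_key}'} if api_key else {}

    assert bearer_headers(None) == {}
    assert bearer_headers('sk-123') == {'Authorization': 'Bearer sk-123'}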
@@ -99,9 +99,9 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
        }

    def _speech2text_invoke(
        self,
        model: str,
        credentials: dict,
        self,
        model: str,
        credentials: dict,
        file: IO[bytes],
        language: Optional[str] = None,
        prompt: Optional[str] = None,
@@ -121,17 +121,24 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
        :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
        :return: text for given audio file
        """
        if credentials['server_url'].endswith('/'):
            credentials['server_url'] = credentials['server_url'][:-1]
        server_url = credentials['server_url']
        model_uid = credentials['model_uid']
        api_key = credentials.get('api_key')
        if server_url.endswith('/'):
            server_url = server_url[:-1]
        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}

        handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
        response = handle.transcriptions(
            audio=file,
            language = language,
            prompt = prompt,
            response_format = response_format,
            temperature = temperature
        )
        try:
            handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers)
            response = handle.transcriptions(
                audio=file,
                language=language,
                prompt=prompt,
                response_format=response_format,
                temperature=temperature
            )
        except RuntimeError as e:
            raise InvokeServerUnavailableError(str(e))

        return response["text"]
@@ -43,16 +43,17 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
        """
        server_url = credentials['server_url']
        model_uid = credentials['model_uid']

        api_key = credentials.get('api_key')
        if server_url.endswith('/'):
            server_url = server_url[:-1]
        auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}

        try:
            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
            handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers)
            embeddings = handle.create_embedding(input=texts)
        except RuntimeError as e:
            raise InvokeServerUnavailableError(e)
            raise InvokeServerUnavailableError(str(e))

        """
        for convenience, the response json is like:
        class Embedding(TypedDict):
@@ -106,7 +107,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
        try:
            if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
                raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

            server_url = credentials['server_url']
            model_uid = credentials['model_uid']
            extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
@@ -117,7 +118,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
                server_url = server_url[:-1]

            client = Client(base_url=server_url)

            try:
                handle = client.get_model(model_uid=model_uid)
            except RuntimeError as e:
@@ -151,7 +152,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
                KeyError
            ]
        }

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage
@@ -186,7 +187,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
        """
        used to define customizable model schema
        """

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
@@ -46,3 +46,12 @@ model_credential_schema:
      placeholder:
        zh_Hans: 在此输入您的Model UID
        en_US: Enter the model uid
    - variable: api_key
      label:
        zh_Hans: API密钥
        en_US: API key
      type: text-input
      required: false
      placeholder:
        zh_Hans: 在此输入您的API密钥
        en_US: Enter the api key
@@ -7,8 +7,8 @@ _import_err_msg = (
    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from flask import current_app

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -36,7 +36,7 @@ class AnalyticdbConfig(BaseModel):
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }


class AnalyticdbVector(BaseVector):
    _instance = None
    _init = False
@@ -45,7 +45,7 @@ class AnalyticdbVector(BaseVector):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, collection_name: str, config: AnalyticdbConfig):
        # collection_name must be updated every time
        self._collection_name = collection_name.lower()
@@ -105,7 +105,7 @@ class AnalyticdbVector(BaseVector):
            raise ValueError(
                f"failed to create namespace {self.config.namespace}: {e}"
            )

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException
@@ -149,7 +149,7 @@ class AnalyticdbVector(BaseVector):

    def get_type(self) -> str:
        return VectorType.ANALYTICDB

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        dimension = len(embeddings[0])
        self._create_collection_if_not_exists(dimension)
@@ -199,7 +199,7 @@ class AnalyticdbVector(BaseVector):
        )
        response = self._client.query_collection_data(request)
        return len(response.body.matches.match) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        ids_str = ",".join(f"'{id}'" for id in ids)
@@ -260,7 +260,7 @@ class AnalyticdbVector(BaseVector):
            )
            documents.append(doc)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        score_threshold = (
@@ -291,7 +291,7 @@ class AnalyticdbVector(BaseVector):
            )
            documents.append(doc)
        return documents

    def delete(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        request = gpdb_20160503_models.DeleteCollectionRequest(
@@ -316,17 +316,18 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
        dataset.index_struct = json.dumps(
            self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
        )
        config = current_app.config

        # TODO handle optional params
        return AnalyticdbVector(
            collection_name,
            AnalyticdbConfig(
                access_key_id=config.get("ANALYTICDB_KEY_ID"),
                access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
                region_id=config.get("ANALYTICDB_REGION_ID"),
                instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
                account=config.get("ANALYTICDB_ACCOUNT"),
                account_password=config.get("ANALYTICDB_PASSWORD"),
                namespace=config.get("ANALYTICDB_NAMESPACE"),
                namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
                access_key_id=dify_config.ANALYTICDB_KEY_ID,
                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
                region_id=dify_config.ANALYTICDB_REGION_ID,
                instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
                account=dify_config.ANALYTICDB_ACCOUNT,
                account_password=dify_config.ANALYTICDB_PASSWORD,
                namespace=dify_config.ANALYTICDB_NAMESPACE,
                namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
            ),
        )
        )
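This factory is the template for the vector-store hunks that follow: untyped current_app.config.get(...) lookups, which require an active Flask application context, are replaced by attribute access on dify_config, a pydantic-settings object validated once at startup. A minimal sketch of the idea with hypothetical setting names, not Dify's actual config classes:

    from typing import Optional

    from pydantic import Field
    from pydantic_settings import BaseSettings

    class ExampleVectorDBConfig(BaseSettings):
        # Typed and validated when the settings object is created;
        # readable anywhere without a Flask app context.
        EXAMPLE_HOST: Optional[str] = Field(description='database host', default=None)
        EXAMPLE_PORT: int = Field(description='database port', default=8123)

    example_config = ExampleVectorDBConfig()
    print(example_config.EXAMPLE_PORT)  # attribute access replaces config.get('EXAMPLE_PORT')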
@@ -3,9 +3,9 @@ from typing import Any, Optional

import chromadb
from chromadb import QueryResult, Settings
from flask import current_app
from pydantic import BaseModel

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -111,7 +111,8 @@ class ChromaVector(BaseVector):
                metadata=metadata,
            )
            docs.append(doc)

        # Sort the documents by score in descending order
        docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -133,15 +134,14 @@ class ChromaVectorFactory(AbstractVectorFactory):
        }
        dataset.index_struct = json.dumps(index_struct_dict)

        config = current_app.config
        return ChromaVector(
            collection_name=collection_name,
            config=ChromaConfig(
                host=config.get('CHROMA_HOST'),
                port=int(config.get('CHROMA_PORT')),
                tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT),
                database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE),
                auth_provider=config.get('CHROMA_AUTH_PROVIDER'),
                auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'),
                host=dify_config.CHROMA_HOST,
                port=dify_config.CHROMA_PORT,
                tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
                database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
                auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
                auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
            ),
        )
@@ -3,10 +3,10 @@ import logging
from typing import Any, Optional
from uuid import uuid4

from flask import current_app
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException, connections

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -275,15 +275,14 @@ class MilvusVectorFactory(AbstractVectorFactory):
        dataset.index_struct = json.dumps(
            self.gen_index_struct_dict(VectorType.MILVUS, collection_name))

        config = current_app.config
        return MilvusVector(
            collection_name=collection_name,
            config=MilvusConfig(
                host=config.get('MILVUS_HOST'),
                port=config.get('MILVUS_PORT'),
                user=config.get('MILVUS_USER'),
                password=config.get('MILVUS_PASSWORD'),
                secure=config.get('MILVUS_SECURE'),
                database=config.get('MILVUS_DATABASE'),
                host=dify_config.MILVUS_HOST,
                port=dify_config.MILVUS_PORT,
                user=dify_config.MILVUS_USER,
                password=dify_config.MILVUS_PASSWORD,
                secure=dify_config.MILVUS_SECURE,
                database=dify_config.MILVUS_DATABASE,
            )
        )
@@ -5,9 +5,9 @@ from enum import Enum
from typing import Any

from clickhouse_connect import get_client
from flask import current_app
from pydantic import BaseModel

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -156,15 +156,14 @@ class MyScaleVectorFactory(AbstractVectorFactory):
        dataset.index_struct = json.dumps(
            self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))

        config = current_app.config
        return MyScaleVector(
            collection_name=collection_name,
            config=MyScaleConfig(
                host=config.get("MYSCALE_HOST", "localhost"),
                port=int(config.get("MYSCALE_PORT", 8123)),
                user=config.get("MYSCALE_USER", "default"),
                password=config.get("MYSCALE_PASSWORD", ""),
                database=config.get("MYSCALE_DATABASE", "default"),
                fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
                host=dify_config.MYSCALE_HOST,
                port=dify_config.MYSCALE_PORT,
                user=dify_config.MYSCALE_USER,
                password=dify_config.MYSCALE_PASSWORD,
                database=dify_config.MYSCALE_DATABASE,
                fts_params=dify_config.MYSCALE_FTS_PARAMS,
            ),
        )
@@ -4,11 +4,11 @@ import ssl
from typing import Any, Optional
from uuid import uuid4

from flask import current_app
from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -257,14 +257,13 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
        dataset.index_struct = json.dumps(
            self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))

        config = current_app.config

        open_search_config = OpenSearchConfig(
            host=config.get('OPENSEARCH_HOST'),
            port=config.get('OPENSEARCH_PORT'),
            user=config.get('OPENSEARCH_USER'),
            password=config.get('OPENSEARCH_PASSWORD'),
            secure=config.get('OPENSEARCH_SECURE'),
            host=dify_config.OPENSEARCH_HOST,
            port=dify_config.OPENSEARCH_PORT,
            user=dify_config.OPENSEARCH_USER,
            password=dify_config.OPENSEARCH_PASSWORD,
            secure=dify_config.OPENSEARCH_SECURE,
        )

        return OpenSearchVector(
@@ -6,9 +6,9 @@ from typing import Any

import numpy
import oracledb
from flask import current_app
from pydantic import BaseModel, model_validator

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -44,11 +44,11 @@ class OracleVectorConfig(BaseModel):

SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
    id varchar2(100)
    id varchar2(100)
    ,text CLOB NOT NULL
    ,meta JSON
    ,embedding vector NOT NULL
)
)
"""

@@ -219,14 +219,13 @@ class OracleVectorFactory(AbstractVectorFactory):
        dataset.index_struct = json.dumps(
            self.gen_index_struct_dict(VectorType.ORACLE, collection_name))

        config = current_app.config
        return OracleVector(
            collection_name=collection_name,
            config=OracleVectorConfig(
                host=config.get("ORACLE_HOST"),
                port=config.get("ORACLE_PORT"),
                user=config.get("ORACLE_USER"),
                password=config.get("ORACLE_PASSWORD"),
                database=config.get("ORACLE_DATABASE"),
                host=dify_config.ORACLE_HOST,
                port=dify_config.ORACLE_PORT,
                user=dify_config.ORACLE_USER,
                password=dify_config.ORACLE_PASSWORD,
                database=dify_config.ORACLE_DATABASE,
            ),
        )
@@ -3,7 +3,6 @@ import logging
from typing import Any
from uuid import UUID, uuid4

from flask import current_app
from numpy import ndarray
from pgvecto_rs.sqlalchemy import Vector
from pydantic import BaseModel, model_validator
@@ -12,6 +11,7 @@ from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -93,7 +93,7 @@ class PGVectoRS(BaseVector):
                text TEXT NOT NULL,
                meta JSONB NOT NULL,
                vector vector({dimension}) NOT NULL
            ) using heap;
            ) using heap;
        """)
        session.execute(create_statement)
        index_statement = sql_text(f"""
@@ -233,15 +233,15 @@ class PGVectoRSFactory(AbstractVectorFactory):
        dataset.index_struct = json.dumps(
            self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
        dim = len(embeddings.embed_query("pgvecto_rs"))
        config = current_app.config

        return PGVectoRS(
            collection_name=collection_name,
            config=PgvectoRSConfig(
                host=config.get('PGVECTO_RS_HOST'),
                port=config.get('PGVECTO_RS_PORT'),
                user=config.get('PGVECTO_RS_USER'),
                password=config.get('PGVECTO_RS_PASSWORD'),
                database=config.get('PGVECTO_RS_DATABASE'),
                host=dify_config.PGVECTO_RS_HOST,
                port=dify_config.PGVECTO_RS_PORT,
                user=dify_config.PGVECTO_RS_USER,
                password=dify_config.PGVECTO_RS_PASSWORD,
                database=dify_config.PGVECTO_RS_DATABASE,
            ),
            dim=dim
        )
        )
@@ -5,9 +5,9 @@ from typing import Any

import psycopg2.extras
import psycopg2.pool
from flask import current_app
from pydantic import BaseModel, model_validator

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -45,7 +45,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
    text TEXT NOT NULL,
    meta JSONB NOT NULL,
    embedding vector({dimension}) NOT NULL
) using heap;
) using heap;
"""

@@ -185,14 +185,13 @@ class PGVectorFactory(AbstractVectorFactory):
        dataset.index_struct = json.dumps(
            self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))

        config = current_app.config
        return PGVector(
            collection_name=collection_name,
            config=PGVectorConfig(
                host=config.get("PGVECTOR_HOST"),
                port=config.get("PGVECTOR_PORT"),
                user=config.get("PGVECTOR_USER"),
                password=config.get("PGVECTOR_PASSWORD"),
                database=config.get("PGVECTOR_DATABASE"),
                host=dify_config.PGVECTOR_HOST,
                port=dify_config.PGVECTOR_PORT,
                user=dify_config.PGVECTOR_USER,
                password=dify_config.PGVECTOR_PASSWORD,
                database=dify_config.PGVECTOR_DATABASE,
            ),
        )
        )
@@ -19,6 +19,7 @@ from qdrant_client.http.models import (
)
from qdrant_client.local.qdrant_local import QdrantLocal

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -361,6 +362,8 @@ class QdrantVector(BaseVector):
                metadata=metadata,
            )
            docs.append(doc)
        # Sort the documents by score in descending order
        docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -444,11 +447,11 @@ class QdrantVectorFactory(AbstractVectorFactory):
            collection_name=collection_name,
            group_id=dataset.id,
            config=QdrantConfig(
                endpoint=config.get('QDRANT_URL'),
                api_key=config.get('QDRANT_API_KEY'),
                endpoint=dify_config.QDRANT_URL,
                api_key=dify_config.QDRANT_API_KEY,
                root_path=config.root_path,
                timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
                grpc_port=config.get('QDRANT_GRPC_PORT'),
                prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
                timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
                grpc_port=dify_config.QDRANT_GRPC_PORT,
                prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
            )
        )
@@ -2,7 +2,6 @@ import json
import uuid
from typing import Any, Optional

from flask import current_app
from pydantic import BaseModel, model_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
@@ -19,6 +18,7 @@ try:
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base

from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
@@ -85,7 +85,7 @@ class RelytVector(BaseVector):
            document TEXT NOT NULL,
            metadata JSON NOT NULL,
            embedding vector({dimension}) NOT NULL
        ) using heap;
        ) using heap;
        """)
        session.execute(create_statement)
        index_statement = sql_text(f"""
@@ -313,15 +313,14 @@ class RelytVectorFactory(AbstractVectorFactory):
        dataset.index_struct = json.dumps(
            self.gen_index_struct_dict(VectorType.RELYT, collection_name))

        config = current_app.config
        return RelytVector(
            collection_name=collection_name,
            config=RelytConfig(
                host=config.get('RELYT_HOST'),
                port=config.get('RELYT_PORT'),
                user=config.get('RELYT_USER'),
                password=config.get('RELYT_PASSWORD'),
                database=config.get('RELYT_DATABASE'),
                host=dify_config.RELYT_HOST,
                port=dify_config.RELYT_PORT,
                user=dify_config.RELYT_USER,
                password=dify_config.RELYT_PASSWORD,
                database=dify_config.RELYT_DATABASE,
            ),
            group_id=dataset.id
        )
@@ -1,13 +1,13 @@
import json
from typing import Any, Optional

from flask import current_app
from pydantic import BaseModel
from tcvectordb import VectorDBClient
from tcvectordb.model import document, enum
from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -212,16 +212,15 @@ class TencentVectorFactory(AbstractVectorFactory):
        dataset.index_struct = json.dumps(
            self.gen_index_struct_dict(VectorType.TENCENT, collection_name))

        config = current_app.config
        return TencentVector(
            collection_name=collection_name,
            config=TencentConfig(
                url=config.get('TENCENT_VECTOR_DB_URL'),
                api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
                timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
                username=config.get('TENCENT_VECTOR_DB_USERNAME'),
                database=config.get('TENCENT_VECTOR_DB_DATABASE'),
                shard=config.get('TENCENT_VECTOR_DB_SHARD'),
                replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
                url=dify_config.TENCENT_VECTOR_DB_URL,
                api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
                timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
                username=dify_config.TENCENT_VECTOR_DB_USERNAME,
                database=dify_config.TENCENT_VECTOR_DB_DATABASE,
                shard=dify_config.TENCENT_VECTOR_DB_SHARD,
                replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
            )
        )
        )
@@ -3,12 +3,12 @@ import logging
from typing import Any

import sqlalchemy
from flask import current_app
from pydantic import BaseModel, model_validator
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
@@ -198,8 +198,8 @@ class TiDBVector(BaseVector):
        with Session(self._engine) as session:
            select_statement = sql_text(
                f"""SELECT meta, text, distance FROM (
                        SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
                        FROM {self._collection_name}
                    SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
                    FROM {self._collection_name}
                    ORDER BY distance
                    LIMIT {top_k}
                ) t WHERE distance < {distance};"""
@@ -234,15 +234,14 @@ class TiDBVectorFactory(AbstractVectorFactory):
        dataset.index_struct = json.dumps(
            self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))

        config = current_app.config
        return TiDBVector(
            collection_name=collection_name,
            config=TiDBVectorConfig(
                host=config.get('TIDB_VECTOR_HOST'),
                port=config.get('TIDB_VECTOR_PORT'),
                user=config.get('TIDB_VECTOR_USER'),
                password=config.get('TIDB_VECTOR_PASSWORD'),
                database=config.get('TIDB_VECTOR_DATABASE'),
                program_name=config.get('APPLICATION_NAME'),
                host=dify_config.TIDB_VECTOR_HOST,
                port=dify_config.TIDB_VECTOR_PORT,
                user=dify_config.TIDB_VECTOR_USER,
                password=dify_config.TIDB_VECTOR_PASSWORD,
                database=dify_config.TIDB_VECTOR_DATABASE,
                program_name=dify_config.APPLICATION_NAME,
            ),
        )
        )
@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any

from flask import current_app

from configs import dify_config
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -37,8 +36,7 @@ class Vector:
        self._vector_processor = self._init_vector()

    def _init_vector(self) -> BaseVector:
        config = current_app.config
        vector_type = config.get('VECTOR_STORE')
        vector_type = dify_config.VECTOR_STORE
        if self._dataset.index_struct_dict:
            vector_type = self._dataset.index_struct_dict['type']
@@ -4,9 +4,9 @@ from typing import Any, Optional

import requests
import weaviate
from flask import current_app
from pydantic import BaseModel, model_validator

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -216,7 +216,8 @@ class WeaviateVector(BaseVector):
            if score > score_threshold:
                doc.metadata['score'] = score
                docs.append(doc)

        # Sort the documents by score in descending order
        docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -281,9 +282,9 @@ class WeaviateVectorFactory(AbstractVectorFactory):
        return WeaviateVector(
            collection_name=collection_name,
            config=WeaviateConfig(
                endpoint=current_app.config.get('WEAVIATE_ENDPOINT'),
                api_key=current_app.config.get('WEAVIATE_API_KEY'),
                batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE'))
                endpoint=dify_config.WEAVIATE_ENDPOINT,
                api_key=dify_config.WEAVIATE_API_KEY,
                batch_size=dify_config.WEAVIATE_BATCH_SIZE
            ),
            attributes=attributes
        )
@@ -5,8 +5,8 @@ from typing import Union
from urllib.parse import unquote

import requests
from flask import current_app

from configs import dify_config
from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -94,9 +94,9 @@ class ExtractProcessor:
            storage.download(upload_file.key, file_path)
            input_file = Path(file_path)
            file_extension = input_file.suffix.lower()
            etl_type = current_app.config['ETL_TYPE']
            unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
            unstructured_api_key = current_app.config['UNSTRUCTURED_API_KEY']
            etl_type = dify_config.ETL_TYPE
            unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
            unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
            if etl_type == 'Unstructured':
                if file_extension == '.xlsx' or file_extension == '.xls':
                    extractor = ExcelExtractor(file_path)
@@ -3,8 +3,8 @@ import logging
from typing import Any, Optional

import requests
from flask import current_app

from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
@@ -49,7 +49,7 @@ class NotionExtractor(BaseExtractor):
            self._notion_access_token = self._get_access_token(tenant_id,
                                                               self._notion_workspace_id)
            if not self._notion_access_token:
                integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
                integration_token = dify_config.NOTION_INTEGRATION_TOKEN
                if integration_token is None:
                    raise ValueError(
                        "Must specify `integration_token` or set environment "
@@ -8,8 +8,8 @@ from urllib.parse import urlparse

import requests
from docx import Document as DocxDocument
from flask import current_app

from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
@@ -96,10 +96,9 @@ class WordExtractor(BaseExtractor):

            storage.save(file_key, rel.target_part.blob)
            # save file to db
            config = current_app.config
            upload_file = UploadFile(
                tenant_id=self.tenant_id,
                storage_type=config['STORAGE_TYPE'],
                storage_type=dify_config.STORAGE_TYPE,
                key=file_key,
                name=file_key,
                size=0,
@@ -114,7 +113,7 @@ class WordExtractor(BaseExtractor):

            db.session.add(upload_file)
            db.session.commit()
            image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"
            image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"

        return image_map
@@ -2,8 +2,7 @@
from abc import ABC, abstractmethod
from typing import Optional

from flask import current_app

from configs import dify_config
from core.model_manager import ModelInstance
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
@@ -48,7 +47,7 @@ class BaseIndexProcessor(ABC):
        # The user-defined segmentation rule
        rules = processing_rule['rules']
        segmentation = rules["segmentation"]
        max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
        max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
        if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
            raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
@@ -30,3 +30,4 @@
- feishu
- feishu_base
- slack
- tianditu
@@ -1,5 +1,5 @@
import base64
import random
from base64 import b64decode
from typing import Any, Union

from openai import OpenAI
@@ -69,11 +69,50 @@ class DallE3Tool(BuiltinTool):
        result = []

        for image in response.data:
            result.append(self.create_blob_message(blob=b64decode(image.b64_json),
                                                   meta={'mime_type': 'image/png'},
                                                   save_as=self.VARIABLE_KEY.IMAGE.value))
            mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
            blob_message = self.create_blob_message(blob=blob_image,
                                                    meta={'mime_type': mime_type},
                                                    save_as=self.VARIABLE_KEY.IMAGE.value)
            result.append(blob_message)
        return result

    @staticmethod
    def _decode_image(base64_image: str) -> tuple[str, bytes]:
        """
        Decode a base64 encoded image. If the image is not prefixed with a MIME type,
        it assumes 'image/png' as the default.

        :param base64_image: Base64 encoded image string
        :return: A tuple containing the MIME type and the decoded image bytes
        """
        if DallE3Tool._is_plain_base64(base64_image):
            return 'image/png', base64.b64decode(base64_image)
        else:
            return DallE3Tool._extract_mime_and_data(base64_image)

    @staticmethod
    def _is_plain_base64(encoded_str: str) -> bool:
        """
        Check if the given encoded string is plain base64 without a MIME type prefix.

        :param encoded_str: Base64 encoded image string
        :return: True if the string is plain base64, False otherwise
        """
        return not encoded_str.startswith('data:image')

    @staticmethod
    def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
        """
        Extract MIME type and image data from a base64 encoded string with a MIME type prefix.

        :param encoded_str: Base64 encoded image string with MIME type prefix
        :return: A tuple containing the MIME type and the decoded image bytes
        """
        mime_type = encoded_str.split(';')[0].split(':')[1]
        image_data_base64 = encoded_str.split(',')[1]
        decoded_data = base64.b64decode(image_data_base64)
        return mime_type, decoded_data

    @staticmethod
    def _generate_random_id(length=8):
        characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
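To make the decoding split concrete: plain base64 is assumed to be PNG, while a data URI carries its own MIME type. A standalone restatement of the same logic (hypothetical helper, not part of the diff):

    import base64

    def decode_image(encoded: str) -> tuple[str, bytes]:
        # Data-URI input declares its MIME type; bare base64 defaults to image/png
        if encoded.startswith('data:image'):
            mime_type = encoded.split(';')[0].split(':')[1]
            return mime_type, base64.b64decode(encoded.split(',')[1])
        return 'image/png', base64.b64decode(encoded)

    plain = base64.b64encode(b'fake image bytes').decode()
    assert decode_image(plain)[0] == 'image/png'
    assert decode_image('data:image/jpeg;base64,' + plain)[0] == 'image/jpeg'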
@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 256 256"><rect width="256" height="256" fill="none"/><rect x="32" y="48" width="192" height="160" rx="8" fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/><circle cx="156" cy="100" r="12"/><path d="M147.31,164,173,138.34a8,8,0,0,1,11.31,0L224,178.06" fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/><path d="M32,168.69l54.34-54.35a8,8,0,0,1,11.32,0L191.31,208" fill="none" stroke="#1553ed" stroke-linecap="round" stroke-linejoin="round" stroke-width="16"/></svg>
22  api/core/tools/provider/builtin/getimgai/getimgai.py  Normal file
@@ -0,0 +1,22 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.getimgai.tools.text2image import Text2ImageTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class GetImgAIProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict) -> None:
        try:
            # Example validation using the text2image tool
            Text2ImageTool().fork_tool_runtime(
                runtime={"credentials": credentials}
            ).invoke(
                user_id='',
                tool_parameters={
                    "prompt": "A fire egg",
                    "response_format": "url",
                    "style": "photorealism",
                }
            )
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
29  api/core/tools/provider/builtin/getimgai/getimgai.yaml  Normal file
@@ -0,0 +1,29 @@
identity:
  author: Matri Qi
  name: getimgai
  label:
    en_US: getimg.ai
    zh_CN: getimg.ai
  description:
    en_US: GetImg API integration for image generation and scraping.
  icon: icon.svg
  tags:
    - image
credentials_for_provider:
  getimg_api_key:
    type: secret-input
    required: true
    label:
      en_US: getimg.ai API Key
    placeholder:
      en_US: Please input your getimg.ai API key
    help:
      en_US: Get your getimg.ai API key from your getimg.ai account settings. If you are using a self-hosted version, you may enter any key at your convenience.
      url: https://dashboard.getimg.ai/api-keys
  base_url:
    type: text-input
    required: false
    label:
      en_US: getimg.ai server's Base URL
    placeholder:
      en_US: https://api.getimg.ai/v1
59  api/core/tools/provider/builtin/getimgai/getimgai_appx.py  Normal file
@@ -0,0 +1,59 @@
import logging
import time
from collections.abc import Mapping
from typing import Any

import requests
from requests.exceptions import HTTPError

logger = logging.getLogger(__name__)


class GetImgAIApp:
    def __init__(self, api_key: str | None = None, base_url: str | None = None):
        self.api_key = api_key
        self.base_url = base_url or 'https://api.getimg.ai/v1'
        if not self.api_key:
            raise ValueError("API key is required")

    def _prepare_headers(self):
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}'
        }
        return headers

    def _request(
        self,
        method: str,
        url: str,
        data: Mapping[str, Any] | None = None,
        headers: Mapping[str, str] | None = None,
        retries: int = 3,
        backoff_factor: float = 0.3,
    ) -> Mapping[str, Any] | None:
        for i in range(retries):
            try:
                response = requests.request(method, url, json=data, headers=headers)
                response.raise_for_status()
                return response.json()
            except requests.exceptions.RequestException as e:
                if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500:
                    time.sleep(backoff_factor * (2 ** i))
                else:
                    raise
        return None

    def text2image(
        self, mode: str, **kwargs
    ):
        data = kwargs['params']
        if not data.get('prompt'):
            raise ValueError("Prompt is required")

        endpoint = f'{self.base_url}/{mode}/text-to-image'
        headers = self._prepare_headers()
        logger.debug(f"Send request to {endpoint=} body={data}")
        response = self._request('POST', endpoint, data, headers)
        if response is None:
            raise HTTPError("Failed to initiate getimg.ai after multiple retries")
        return response
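The _request helper retries only HTTP 5xx responses, sleeping backoff_factor * 2**i before the next attempt, so the defaults give delays of 0.3 s and then 0.6 s across three attempts; client errors and the final failure propagate immediately. The schedule, spelled out:

    backoff_factor, retries = 0.3, 3
    delays = [backoff_factor * (2 ** i) for i in range(retries - 1)]
    print(delays)  # [0.3, 0.6] - no sleep after the last attempt; the exception is re-raised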
39  api/core/tools/provider/builtin/getimgai/tools/text2image.py  Normal file
@@ -0,0 +1,39 @@
import json
from typing import Any, Union

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.getimgai.getimgai_appx import GetImgAIApp
from core.tools.tool.builtin_tool import BuiltinTool


class Text2ImageTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials['base_url'])

        options = {
            'style': tool_parameters.get('style'),
            'prompt': tool_parameters.get('prompt'),
            'aspect_ratio': tool_parameters.get('aspect_ratio'),
            'output_format': tool_parameters.get('output_format', 'jpeg'),
            'response_format': tool_parameters.get('response_format', 'url'),
            'width': tool_parameters.get('width'),
            'height': tool_parameters.get('height'),
            'steps': tool_parameters.get('steps'),
            'negative_prompt': tool_parameters.get('negative_prompt'),
            'prompt_2': tool_parameters.get('prompt_2'),
        }
        options = {k: v for k, v in options.items() if v}

        text2image_result = app.text2image(
            mode=tool_parameters.get('mode', 'essential-v2'),
            params=options,
            wait=True
        )

        if not isinstance(text2image_result, str):
            text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4)

        if not text2image_result:
            return self.create_text_message("getimg.ai request failed.")

        return self.create_text_message(text2image_result)
167  api/core/tools/provider/builtin/getimgai/tools/text2image.yaml  Normal file
@@ -0,0 +1,167 @@
identity:
  name: text2image
  author: Matri Qi
  label:
    en_US: text2image
  icon: icon.svg
description:
  human:
    en_US: Generate image via getimg.ai.
  llm: This tool is used to generate image from prompt or image via https://getimg.ai.
parameters:
  - name: prompt
    type: string
    required: true
    label:
      en_US: prompt
    human_description:
      en_US: The text prompt used to generate the image. getimg.ai will generate an image based on this prompt.
    llm_description: this prompt text will be used to generate image.
    form: llm
  - name: mode
    type: select
    required: false
    label:
      en_US: mode
    human_description:
      en_US: The getimg.ai mode to use. The mode determines the endpoint used to generate the image.
    form: form
    options:
      - value: "essential-v2"
        label:
          en_US: essential-v2
      - value: stable-diffusion-xl
        label:
          en_US: stable-diffusion-xl
      - value: stable-diffusion
        label:
          en_US: stable-diffusion
      - value: latent-consistency
        label:
          en_US: latent-consistency
  - name: style
    type: select
    required: false
    label:
      en_US: style
    human_description:
      en_US: The style preset to use. The style preset guides the generation towards a particular style. It only takes effect in `Essential V2` mode.
    form: form
    options:
      - value: photorealism
        label:
          en_US: photorealism
      - value: anime
        label:
          en_US: anime
      - value: art
        label:
          en_US: art
  - name: aspect_ratio
    type: select
    required: false
    label:
      en_US: "aspect ratio"
    human_description:
      en_US: The aspect ratio of the generated image. It only takes effect in `Essential V2` mode.
    form: form
    options:
      - value: "1:1"
        label:
          en_US: "1:1"
      - value: "4:5"
        label:
          en_US: "4:5"
      - value: "5:4"
        label:
          en_US: "5:4"
      - value: "2:3"
        label:
          en_US: "2:3"
      - value: "3:2"
        label:
          en_US: "3:2"
      - value: "4:7"
        label:
          en_US: "4:7"
      - value: "7:4"
        label:
          en_US: "7:4"
  - name: output_format
    type: select
    required: false
    label:
      en_US: "output format"
    human_description:
      en_US: The file format of the generated image.
    form: form
    options:
      - value: jpeg
        label:
          en_US: jpeg
      - value: png
        label:
          en_US: png
  - name: response_format
    type: select
    required: false
    label:
      en_US: "response format"
    human_description:
      en_US: The format in which the generated images are returned. Must be one of url or b64. URLs are only valid for 1 hour after the image has been generated.
    form: form
    options:
      - value: url
        label:
          en_US: url
      - value: b64
        label:
          en_US: b64
  - name: model
    type: string
    required: false
    label:
      en_US: model
    human_description:
      en_US: Model ID supported by this pipeline and family. It only takes effect in `Stable Diffusion XL`, `Stable Diffusion`, and `Latent Consistency` modes.
    form: form
  - name: negative_prompt
    type: string
    required: false
    label:
      en_US: negative prompt
    human_description:
      en_US: Text input that will not guide the image generation. It only takes effect in `Stable Diffusion XL`, `Stable Diffusion`, and `Latent Consistency` modes.
    form: form
  - name: prompt_2
    type: string
    required: false
    label:
      en_US: prompt2
    human_description:
      en_US: Prompt sent to the second tokenizer and text encoder. If not defined, prompt is used in both text-encoders. It only takes effect in `Stable Diffusion XL` mode.
    form: form
  - name: width
    type: number
    required: false
    label:
      en_US: width
    human_description:
      en_US: The width of the generated image in pixels. Width needs to be a multiple of 64.
    form: form
  - name: height
    type: number
    required: false
    label:
      en_US: height
    human_description:
      en_US: The height of the generated image in pixels. Height needs to be a multiple of 64.
    form: form
  - name: steps
    type: number
    required: false
    label:
      en_US: steps
    human_description:
      en_US: The number of denoising steps. More steps usually produce higher quality images but take more time to generate. It only takes effect in `Stable Diffusion XL`, `Stable Diffusion`, and `Latent Consistency` modes.
    form: form
@@ -19,28 +19,29 @@ class JSONDeleteTool(BuiltinTool):
        content = tool_parameters.get('content', '')
        if not content:
            return self.create_text_message('Invalid parameter content')

        # Get query
        query = tool_parameters.get('query', '')
        if not query:
            return self.create_text_message('Invalid parameter query')

        ensure_ascii = tool_parameters.get('ensure_ascii', True)
        try:
            result = self._delete(content, query)
            result = self._delete(content, query, ensure_ascii)
            return self.create_text_message(str(result))
        except Exception as e:
            return self.create_text_message(f'Failed to delete JSON content: {str(e)}')

    def _delete(self, origin_json: str, query: str) -> str:
    def _delete(self, origin_json: str, query: str, ensure_ascii: bool) -> str:
        try:
            input_data = json.loads(origin_json)
            expr = parse('$.' + query.lstrip('$.'))  # Ensure query path starts with $

            matches = expr.find(input_data)

            if not matches:
                return json.dumps(input_data, ensure_ascii=True)  # No changes if no matches found
                return json.dumps(input_data, ensure_ascii=ensure_ascii)  # No changes if no matches found

            for match in matches:
                if isinstance(match.context.value, dict):
                    # Delete key from dictionary
@@ -53,7 +54,7 @@ class JSONDeleteTool(BuiltinTool):
            parent = match.context.parent
            if parent:
                del parent.value[match.path.fields[-1]]

            return json.dumps(input_data, ensure_ascii=True)
            return json.dumps(input_data, ensure_ascii=ensure_ascii)
        except Exception as e:
            raise Exception(f"Delete operation failed: {str(e)}")
            raise Exception(f"Delete operation failed: {str(e)}")
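The ensure_ascii flag threaded through these JSON tools maps directly onto json.dumps: with the default True, non-ASCII characters are escaped; with False, they pass through verbatim. Straight from the standard library:

    import json

    data = {'city': '北京'}
    print(json.dumps(data, ensure_ascii=True))   # {"city": "\u5317\u4eac"}
    print(json.dumps(data, ensure_ascii=False))  # {"city": "北京"}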
@@ -38,3 +38,15 @@
      pt_BR: JSONPath query to locate the element to delete
    llm_description: JSONPath query to locate the element to delete
    form: llm
  - name: ensure_ascii
    type: boolean
    default: true
    label:
      en_US: Ensure ASCII
      zh_Hans: 确保 ASCII
      pt_BR: Ensure ASCII
    human_description:
      en_US: Ensure the JSON output is ASCII encoded
      zh_Hans: 确保输出的 JSON 是 ASCII 编码
      pt_BR: Ensure the JSON output is ASCII encoded
    form: form
@@ -19,31 +19,31 @@ class JSONParseTool(BuiltinTool):
        content = tool_parameters.get('content', '')
        if not content:
            return self.create_text_message('Invalid parameter content')

        # get query
        query = tool_parameters.get('query', '')
        if not query:
            return self.create_text_message('Invalid parameter query')

        # get new value
        new_value = tool_parameters.get('new_value', '')
        if not new_value:
            return self.create_text_message('Invalid parameter new_value')

        # get insert position
        index = tool_parameters.get('index')

        # get create path
        create_path = tool_parameters.get('create_path', False)

        ensure_ascii = tool_parameters.get('ensure_ascii', True)
        try:
            result = self._insert(content, query, new_value, index, create_path)
            result = self._insert(content, query, new_value, ensure_ascii, index, create_path)
            return self.create_text_message(str(result))
        except Exception:
            return self.create_text_message('Failed to insert JSON content')

    def _insert(self, origin_json, query, new_value, index=None, create_path=False):
    def _insert(self, origin_json, query, new_value, ensure_ascii: bool, index=None, create_path=False):
        try:
            input_data = json.loads(origin_json)
            expr = parse(query)
@@ -51,9 +51,9 @@ class JSONParseTool(BuiltinTool):
                new_value = json.loads(new_value)
            except json.JSONDecodeError:
                new_value = new_value

            matches = expr.find(input_data)

            if not matches and create_path:
                # create new path
                path_parts = query.strip('$').strip('.').split('.')
@@ -91,7 +91,7 @@ class JSONParseTool(BuiltinTool):
            else:
                # replace old value with new value
                match.full_path.update(input_data, new_value)

            return json.dumps(input_data, ensure_ascii=True)
            return json.dumps(input_data, ensure_ascii=ensure_ascii)
        except Exception as e:
            return str(e)
            return str(e)
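All of these JSON tools are built on the same three JSONPath calls: `parse` an expression, `find` matches, and `update` through a match's `full_path`. A minimal sketch of that API, assuming the `jsonpath_ng` package (which provides these names with these signatures):

```python
import json

from jsonpath_ng import parse

doc = json.loads('{"user": {"name": "Ada", "tags": ["a", "b"]}}')

expr = parse('$.user.name')
matches = expr.find(doc)                    # list of Match objects
print(matches[0].value)                     # Ada

matches[0].full_path.update(doc, 'Grace')   # in-place update, as _insert/_replace_* do
print(json.dumps(doc, ensure_ascii=False))  # {"user": {"name": "Grace", "tags": ["a", "b"]}}
```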
@@ -75,3 +75,15 @@ parameters:
      zh_Hans: 否
      pt_BR: "No"
    form: form
  - name: ensure_ascii
    type: boolean
    default: true
    label:
      en_US: Ensure ASCII
      zh_Hans: 确保 ASCII
      pt_BR: Ensure ASCII
    human_description:
      en_US: Ensure the JSON output is ASCII encoded
      zh_Hans: 确保输出的 JSON 是 ASCII 编码
      pt_BR: Ensure the JSON output is ASCII encoded
    form: form
@@ -19,33 +19,34 @@ class JSONParseTool(BuiltinTool):
        content = tool_parameters.get('content', '')
        if not content:
            return self.create_text_message('Invalid parameter content')

        # get json filter
        json_filter = tool_parameters.get('json_filter', '')
        if not json_filter:
            return self.create_text_message('Invalid parameter json_filter')

        ensure_ascii = tool_parameters.get('ensure_ascii', True)
        try:
            result = self._extract(content, json_filter)
            result = self._extract(content, json_filter, ensure_ascii)
            return self.create_text_message(str(result))
        except Exception:
            return self.create_text_message('Failed to extract JSON content')

    # Extract data from JSON content
    def _extract(self, content: str, json_filter: str) -> str:
    def _extract(self, content: str, json_filter: str, ensure_ascii: bool) -> str:
        try:
            input_data = json.loads(content)
            expr = parse(json_filter)
            result = [match.value for match in expr.find(input_data)]

            if len(result) == 1:
                result = result[0]

            if isinstance(result, dict | list):
                return json.dumps(result, ensure_ascii=True)
                return json.dumps(result, ensure_ascii=ensure_ascii)
            elif isinstance(result, str | int | float | bool) or result is None:
                return str(result)
            else:
                return repr(result)
        except Exception as e:
            return str(e)
            return str(e)
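Note the single-match unwrapping in `_extract`: one match comes back as a bare value, several as a JSON list. A self-contained sketch of that behavior (a standalone re-implementation for illustration, not the tool itself):

```python
import json

from jsonpath_ng import parse


def extract(content: str, json_filter: str, ensure_ascii: bool = True) -> str:
    # Mirrors _extract above: unwrap a single match, JSON-encode collections
    result = [m.value for m in parse(json_filter).find(json.loads(content))]
    if len(result) == 1:
        result = result[0]
    if isinstance(result, dict | list):
        return json.dumps(result, ensure_ascii=ensure_ascii)
    return str(result)


print(extract('{"a": 1}', '$.a'))          # 1      (bare scalar, not [1])
print(extract('{"a": 1, "b": 2}', '$.*'))  # [1, 2] (JSON list)
```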
@@ -38,3 +38,15 @@ parameters:
      pt_BR: JSON fields to be parsed
    llm_description: JSON fields to be parsed
    form: llm
  - name: ensure_ascii
    type: boolean
    default: true
    label:
      en_US: Ensure ASCII
      zh_Hans: 确保 ASCII
      pt_BR: Ensure ASCII
    human_description:
      en_US: Ensure the JSON output is ASCII encoded
      zh_Hans: 确保输出的 JSON 是 ASCII 编码
      pt_BR: Ensure the JSON output is ASCII encoded
    form: form
@@ -19,61 +19,62 @@ class JSONReplaceTool(BuiltinTool):
        content = tool_parameters.get('content', '')
        if not content:
            return self.create_text_message('Invalid parameter content')

        # get query
        query = tool_parameters.get('query', '')
        if not query:
            return self.create_text_message('Invalid parameter query')

        # get replace value
        replace_value = tool_parameters.get('replace_value', '')
        if not replace_value:
            return self.create_text_message('Invalid parameter replace_value')

        # get replace model
        replace_model = tool_parameters.get('replace_model', '')
        if not replace_model:
            return self.create_text_message('Invalid parameter replace_model')

        ensure_ascii = tool_parameters.get('ensure_ascii', True)
        try:
            if replace_model == 'pattern':
                # get replace pattern
                replace_pattern = tool_parameters.get('replace_pattern', '')
                if not replace_pattern:
                    return self.create_text_message('Invalid parameter replace_pattern')
                result = self._replace_pattern(content, query, replace_pattern, replace_value)
                result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii)
            elif replace_model == 'key':
                result = self._replace_key(content, query, replace_value)
                result = self._replace_key(content, query, replace_value, ensure_ascii)
            elif replace_model == 'value':
                result = self._replace_value(content, query, replace_value)
                result = self._replace_value(content, query, replace_value, ensure_ascii)
            return self.create_text_message(str(result))
        except Exception:
            return self.create_text_message('Failed to replace JSON content')

    # Replace pattern
    def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str) -> str:
    def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool) -> str:
        try:
            input_data = json.loads(content)
            expr = parse(query)

            matches = expr.find(input_data)

            for match in matches:
                new_value = match.value.replace(replace_pattern, replace_value)
                match.full_path.update(input_data, new_value)

            return json.dumps(input_data, ensure_ascii=True)
            return json.dumps(input_data, ensure_ascii=ensure_ascii)
        except Exception as e:
            return str(e)

    # Replace key
    def _replace_key(self, content: str, query: str, replace_value: str) -> str:
    def _replace_key(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str:
        try:
            input_data = json.loads(content)
            expr = parse(query)

            matches = expr.find(input_data)

            for match in matches:
                parent = match.context.value
                if isinstance(parent, dict):
@@ -86,21 +87,21 @@ class JSONReplaceTool(BuiltinTool):
                    if isinstance(item, dict) and old_key in item:
                        value = item.pop(old_key)
                        item[replace_value] = value
            return json.dumps(input_data, ensure_ascii=True)
            return json.dumps(input_data, ensure_ascii=ensure_ascii)
        except Exception as e:
            return str(e)

    # Replace value
    def _replace_value(self, content: str, query: str, replace_value: str) -> str:
    def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool) -> str:
        try:
            input_data = json.loads(content)
            expr = parse(query)

            matches = expr.find(input_data)

            for match in matches:
                match.full_path.update(input_data, replace_value)

            return json.dumps(input_data, ensure_ascii=True)
            return json.dumps(input_data, ensure_ascii=ensure_ascii)
        except Exception as e:
            return str(e)
            return str(e)
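The three replace modes rewrite different things: `pattern` substring-replaces inside matched string values, `key` renames the matched key, and `value` overwrites the matched value outright. A hypothetical before/after on one document:

```python
# Starting document: {"user": {"nickname": "old_name"}}
#
# pattern  query='$.user.nickname', pattern='old', value='new'  -> {"user": {"nickname": "new_name"}}
# key      query='$.user.nickname', value='alias'               -> {"user": {"alias": "old_name"}}
# value    query='$.user.nickname', value='Ada'                 -> {"user": {"nickname": "Ada"}}
```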
@@ -93,3 +93,15 @@ parameters:
      zh_Hans: 字符串替换
      pt_BR: replace string
    form: form
  - name: ensure_ascii
    type: boolean
    default: true
    label:
      en_US: Ensure ASCII
      zh_Hans: 确保 ASCII
      pt_BR: Ensure ASCII
    human_description:
      en_US: Ensure the JSON output is ASCII encoded
      zh_Hans: 确保输出的 JSON 是 ASCII 编码
      pt_BR: Ensure the JSON output is ASCII encoded
    form: form
1
api/core/tools/provider/builtin/spider/_assets/icon.svg
Normal file
@@ -0,0 +1 @@
<svg height="30" width="30" viewBox="0 0 36 34" xml:space="preserve" xmlns="http://www.w3.org/2000/svg" class="fill-accent-foreground transition-all group-hover:scale-110"><title>Spider v1 Logo</title><path fill-rule="evenodd" clip-rule="evenodd" d="M9.13883 7.06589V0.164429L13.0938 0.164429V6.175L14.5178 7.4346C15.577 6.68656 16.7337 6.27495 17.945 6.27495C19.1731 6.27495 20.3451 6.69807 21.4163 7.46593L22.8757 6.175V0.164429L26.8307 0.164429V7.06589V7.95679L26.1634 8.54706L24.0775 10.3922C24.3436 10.8108 24.5958 11.2563 24.8327 11.7262L26.0467 11.4215L28.6971 8.08749L31.793 10.5487L28.7257 14.407L28.3089 14.9313L27.6592 15.0944L26.2418 15.4502C26.3124 15.7082 26.3793 15.9701 26.4422 16.2355L28.653 16.6566L29.092 16.7402L29.4524 17.0045L35.3849 21.355L33.0461 24.5444L27.474 20.4581L27.0719 20.3816C27.1214 21.0613 27.147 21.7543 27.147 22.4577C27.147 22.5398 27.1466 22.6214 27.1459 22.7024L29.5889 23.7911L30.3219 24.1177L30.62 24.8629L33.6873 32.5312L30.0152 34L27.246 27.0769L26.7298 26.8469C25.5612 32.2432 22.0701 33.8808 17.945 33.8808C13.8382 33.8808 10.3598 32.2577 9.17593 26.9185L8.82034 27.0769L6.05109 34L2.37897 32.5312L5.44629 24.8629L5.74435 24.1177L6.47743 23.7911L8.74487 22.7806C8.74366 22.6739 8.74305 22.5663 8.74305 22.4577C8.74305 21.7616 8.76804 21.0758 8.81654 20.4028L8.52606 20.4581L2.95395 24.5444L0.615112 21.355L6.54761 17.0045L6.908 16.7402L7.34701 16.6566L9.44264 16.2575C9.50917 15.9756 9.5801 15.6978 9.65528 15.4242L8.34123 15.0944L7.69155 14.9313L7.27471 14.407L4.20739 10.5487L7.30328 8.08749L9.95376 11.4215L11.0697 11.7016C11.3115 11.2239 11.5692 10.7716 11.8412 10.3473L9.80612 8.54706L9.13883 7.95679V7.06589Z"></path></svg>
14
api/core/tools/provider/builtin/spider/spider.py
Normal file
@@ -0,0 +1,14 @@
from typing import Any

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.spider.spiderApp import Spider
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class SpiderProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        try:
            app = Spider(api_key=credentials["spider_api_key"])
            app.scrape_url(url="https://spider.cloud")
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
27
api/core/tools/provider/builtin/spider/spider.yaml
Normal file
@@ -0,0 +1,27 @@
identity:
  author: William Espegren
  name: spider
  label:
    en_US: Spider
    zh_CN: Spider
  description:
    en_US: Spider API integration, returning LLM-ready data by scraping & crawling websites.
    zh_CN: Spider API 集成,通过爬取和抓取网站返回 LLM-ready 数据。
  icon: icon.svg
  tags:
    - search
    - utilities
credentials_for_provider:
  spider_api_key:
    type: secret-input
    required: true
    label:
      en_US: Spider API Key
      zh_CN: Spider API 密钥
    placeholder:
      en_US: Please input your Spider API key
      zh_CN: 请输入您的 Spider API 密钥
    help:
      en_US: Get your Spider API key from your Spider dashboard
      zh_CN: 从您的 Spider 仪表板中获取 Spider API 密钥。
    url: https://spider.cloud/
237
api/core/tools/provider/builtin/spider/spiderApp.py
Normal file
@@ -0,0 +1,237 @@
import os
from typing import Literal, Optional, TypedDict

import requests


class RequestParamsDict(TypedDict, total=False):
    url: Optional[str]
    request: Optional[Literal["http", "chrome", "smart"]]
    limit: Optional[int]
    return_format: Optional[Literal["raw", "markdown", "html2text", "text", "bytes"]]
    tld: Optional[bool]
    depth: Optional[int]
    cache: Optional[bool]
    budget: Optional[dict[str, int]]
    locale: Optional[str]
    cookies: Optional[str]
    stealth: Optional[bool]
    headers: Optional[dict[str, str]]
    anti_bot: Optional[bool]
    metadata: Optional[bool]
    viewport: Optional[dict[str, int]]
    encoding: Optional[str]
    subdomains: Optional[bool]
    user_agent: Optional[str]
    store_data: Optional[bool]
    gpt_config: Optional[list[str]]
    fingerprint: Optional[bool]
    storageless: Optional[bool]
    readability: Optional[bool]
    proxy_enabled: Optional[bool]
    respect_robots: Optional[bool]
    query_selector: Optional[str]
    full_resources: Optional[bool]
    request_timeout: Optional[int]
    run_in_background: Optional[bool]
    skip_config_checks: Optional[bool]


class Spider:
    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Spider with an API key.

        :param api_key: A string of the API key for Spider. Defaults to the SPIDER_API_KEY environment variable.
        :raises ValueError: If no API key is provided.
        """
        self.api_key = api_key or os.getenv("SPIDER_API_KEY")
        if self.api_key is None:
            raise ValueError("No API key provided")

    def api_post(
        self,
        endpoint: str,
        data: dict,
        stream: bool,
        content_type: str = "application/json",
    ):
        """
        Send a POST request to the specified API endpoint.

        :param endpoint: The API endpoint to which the POST request is sent.
        :param data: The data (dictionary) to be sent in the POST request.
        :param stream: Boolean indicating if the response should be streamed.
        :return: The JSON response or the raw response stream if stream is True.
        """
        headers = self._prepare_headers(content_type)
        response = self._post_request(
            f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream
        )

        if stream:
            return response
        elif response.status_code == 200:
            return response.json()
        else:
            self._handle_error(response, f"post to {endpoint}")

    def api_get(
        self, endpoint: str, stream: bool, content_type: str = "application/json"
    ):
        """
        Send a GET request to the specified endpoint.

        :param endpoint: The API endpoint from which to retrieve data.
        :return: The JSON decoded response.
        """
        headers = self._prepare_headers(content_type)
        response = self._get_request(
            f"https://api.spider.cloud/v1/{endpoint}", headers, stream
        )
        if response.status_code == 200:
            return response.json()
        else:
            self._handle_error(response, f"get from {endpoint}")

    def get_credits(self):
        """
        Retrieve the account's remaining credits.

        :return: JSON response containing the number of credits left.
        """
        return self.api_get("credits", stream=False)

    def scrape_url(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Scrape data from the specified URL.

        :param url: The URL from which to scrape data.
        :param params: Optional dictionary of additional parameters for the scrape request.
        :return: JSON response containing the scraping results.
        """
        params = params or {}  # guard against None before the membership check below

        # Add { "return_format": "markdown" } to the params if not already present
        if "return_format" not in params:
            params["return_format"] = "markdown"

        # Set limit to 1 so only the given page is fetched
        params["limit"] = 1

        return self.api_post("crawl", {"url": url, **params}, stream, content_type)

    def crawl_url(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Start crawling at the specified URL.

        :param url: The URL to begin crawling.
        :param params: Optional dictionary with additional parameters to customize the crawl.
        :param stream: Boolean indicating if the response should be streamed. Defaults to False.
        :return: JSON response or the raw response stream if streaming enabled.
        """
        params = params or {}  # guard against None before the membership check below

        # Add { "return_format": "markdown" } to the params if not already present
        if "return_format" not in params:
            params["return_format"] = "markdown"

        return self.api_post("crawl", {"url": url, **params}, stream, content_type)

    def links(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Retrieve links from the specified URL.

        :param url: The URL from which to extract links.
        :param params: Optional parameters for the link retrieval request.
        :return: JSON response containing the links.
        """
        return self.api_post(
            "links", {"url": url, **(params or {})}, stream, content_type
        )

    def extract_contacts(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Extract contact information from the specified URL.

        :param url: The URL from which to extract contact information.
        :param params: Optional parameters for the contact extraction.
        :return: JSON response containing extracted contact details.
        """
        return self.api_post(
            "pipeline/extract-contacts",
            {"url": url, **(params or {})},
            stream,
            content_type,
        )

    def label(
        self,
        url: str,
        params: Optional[RequestParamsDict] = None,
        stream: bool = False,
        content_type: str = "application/json",
    ):
        """
        Apply labeling to data extracted from the specified URL.

        :param url: The URL to label data from.
        :param params: Optional parameters to guide the labeling process.
        :return: JSON response with labeled data.
        """
        return self.api_post(
            "pipeline/label", {"url": url, **(params or {})}, stream, content_type
        )

    def _prepare_headers(self, content_type: str = "application/json"):
        return {
            "Content-Type": content_type,
            "Authorization": f"Bearer {self.api_key}",
            "User-Agent": "Spider-Client/0.0.27",
        }

    def _post_request(self, url: str, data, headers, stream=False):
        return requests.post(url, headers=headers, json=data, stream=stream)

    def _get_request(self, url: str, headers, stream=False):
        return requests.get(url, headers=headers, stream=stream)

    def _delete_request(self, url: str, headers, stream=False):
        return requests.delete(url, headers=headers, stream=stream)

    def _handle_error(self, response, action):
        if response.status_code in [402, 409, 500]:
            error_message = response.json().get("error", "Unknown error occurred")
            raise Exception(
                f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}"
            )
        else:
            raise Exception(
                f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}"
            )
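A minimal usage sketch for the client above (the key is a placeholder; note that `scrape_url` pins `limit` to 1 and both methods default `return_format` to `markdown`):

```python
from core.tools.provider.builtin.spider.spiderApp import Spider

app = Spider(api_key="sk-...")  # placeholder; falls back to SPIDER_API_KEY if omitted

# Single-page scrape: POSTs to /v1/crawl with limit=1
page = app.scrape_url("https://spider.cloud")

# Site crawl: same endpoint, caller-controlled limit/depth
site = app.crawl_url("https://spider.cloud", params={"limit": 5, "depth": 2})
```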
@@ -0,0 +1,47 @@
from typing import Any, Union

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.spider.spiderApp import Spider
from core.tools.tool.builtin_tool import BuiltinTool


class ScrapeTool(BuiltinTool):
    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        # initialize the app object with the api key
        app = Spider(api_key=self.runtime.credentials['spider_api_key'])

        url = tool_parameters['url']
        mode = tool_parameters['mode']

        options = {
            'limit': tool_parameters.get('limit', 0),
            'depth': tool_parameters.get('depth', 0),
            'blacklist': tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else [],
            'whitelist': tool_parameters.get('whitelist', '').split(',') if tool_parameters.get('whitelist') else [],
            'readability': tool_parameters.get('readability', False),
        }

        result = ""

        try:
            if mode == 'scrape':
                scrape_result = app.scrape_url(
                    url=url,
                    params=options,
                )

                for i in scrape_result:
                    result += "URL: " + i.get('url', '') + "\n"
                    result += "CONTENT: " + i.get('content', '') + "\n\n"
            elif mode == 'crawl':
                crawl_result = app.crawl_url(
                    url=tool_parameters['url'],
                    params=options,
                )
                for i in crawl_result:
                    result += "URL: " + i.get('url', '') + "\n"
                    result += "CONTENT: " + i.get('content', '') + "\n\n"
        except Exception as e:
            return self.create_text_message(f"An error occurred: {str(e)}")

        return self.create_text_message(result)
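One caveat in `_invoke` above: `blacklist` and `whitelist` are split on bare commas, so entries keep any surrounding whitespace. A quick illustration (the stripped variant is a possible hardening, not what the code above does):

```python
raw = "/blog/*, /about"                      # matches the YAML placeholder below
print(raw.split(","))                        # ['/blog/*', ' /about']  <- leading space
print([p.strip() for p in raw.split(",")])   # ['/blog/*', '/about']
```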
@@ -0,0 +1,102 @@
identity:
  name: scraper_crawler
  author: William Espegren
  label:
    en_US: Web Scraper & Crawler
    zh_Hans: 网页抓取与爬虫
description:
  human:
    en_US: A tool for scraping & crawling webpages. Input should be a URL.
    zh_Hans: 用于抓取和爬取网页的工具。输入应该是一个网址。
  llm: A tool for scraping & crawling webpages. Input should be a url.
parameters:
  - name: url
    type: string
    required: true
    label:
      en_US: URL
      zh_Hans: 网址
    human_description:
      en_US: URL to be scraped or crawled
      zh_Hans: 要抓取或爬取的网址
    llm_description: url to either be scraped or crawled
    form: llm
  - name: mode
    type: select
    required: true
    options:
      - value: scrape
        label:
          en_US: scrape
          zh_Hans: 抓取
      - value: crawl
        label:
          en_US: crawl
          zh_Hans: 爬取
    default: crawl
    label:
      en_US: Mode
      zh_Hans: 模式
    human_description:
      en_US: Select whether to scrape a single page or crawl the entire website, following subpages.
      zh_Hans: 用于选择抓取网站或爬取整个网站及其子页面
    form: form
  - name: limit
    type: number
    required: false
    label:
      en_US: Maximum number of pages to crawl
      zh_Hans: 最大爬取页面数
    human_description:
      en_US: Specify the maximum number of pages to crawl per website. The crawler will stop after reaching this limit.
      zh_Hans: 指定每个网站要爬取的最大页面数。爬虫将在达到此限制后停止。
    form: form
    min: 0
    default: 0
  - name: depth
    type: number
    required: false
    label:
      en_US: Maximum depth of pages to crawl
      zh_Hans: 最大爬取深度
    human_description:
      en_US: The crawl limit for maximum depth.
      zh_Hans: 最大爬取深度的限制。
    form: form
    min: 0
    default: 0
  - name: blacklist
    type: string
    required: false
    label:
      en_US: URL patterns to exclude
      zh_Hans: 要排除的URL模式
    human_description:
      en_US: Blacklist a set of paths that you do not want to crawl. You can use regex patterns to help with the list.
      zh_Hans: 指定一组不想爬取的路径。您可以使用正则表达式模式来帮助定义列表。
    placeholder:
      en_US: /blog/*, /about
    form: form
  - name: whitelist
    type: string
    required: false
    label:
      en_US: URL patterns to include
      zh_Hans: 要包含的URL模式
    human_description:
      en_US: Whitelist a set of paths that you want to crawl, ignoring all other routes that do not match the patterns. You can use regex patterns to help with the list.
      zh_Hans: 指定一组要爬取的路径,忽略所有不匹配模式的其他路由。您可以使用正则表达式模式来帮助定义列表。
    placeholder:
      en_US: /blog/*, /about
    form: form
  - name: readability
    type: boolean
    required: false
    label:
      en_US: Pre-process the content for LLM usage
      zh_Hans: 仅返回页面的主要内容
    human_description:
      en_US: Use Mozilla's readability to pre-process the content for reading. This may drastically improve the content for LLM usage.
      zh_Hans: 如果启用,爬虫将仅返回页面的主要内容,不包括标题、导航、页脚等。
    form: form
    default: false
21
api/core/tools/provider/builtin/tianditu/_assets/icon.svg
Normal file
21
api/core/tools/provider/builtin/tianditu/tianditu.py
Normal file
@@ -0,0 +1,21 @@
from typing import Any

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.tianditu.tools.poisearch import PoiSearchTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class TiandituProvider(BuiltinToolProviderController):
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        try:
            PoiSearchTool().fork_tool_runtime(
                runtime={
                    "credentials": credentials,
                }
            ).invoke(user_id='',
                     tool_parameters={
                         'content': '北京',
                         'specify': '156110000',
                     })
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e))
32
api/core/tools/provider/builtin/tianditu/tianditu.yaml
Normal file
@@ -0,0 +1,32 @@
identity:
  author: Listeng
  name: tianditu
  label:
    en_US: Tianditu
    zh_Hans: 天地图
    pt_BR: Tianditu
  description:
    en_US: The Tianditu tool provides place name search, geocoding, static map generation, and other functions within China.
    zh_Hans: 天地图工具可以调用天地图的接口,实现中国区域内的地名搜索、地理编码、静态地图等功能。
    pt_BR: The Tianditu tool provides place name search, geocoding, static map generation, and other functions within China.
  icon: icon.svg
  tags:
    - utilities
    - travel
credentials_for_provider:
  tianditu_api_key:
    type: secret-input
    required: true
    label:
      en_US: Tianditu API Key
      zh_Hans: 天地图Key
      pt_BR: Tianditu API key
    placeholder:
      en_US: Please input your Tianditu API key
      zh_Hans: 请输入你的天地图Key
      pt_BR: Please input your Tianditu API key
    help:
      en_US: Get your Tianditu API key from Tianditu
      zh_Hans: 获取您的天地图Key
      pt_BR: Get your Tianditu API key from Tianditu
    url: http://lbs.tianditu.gov.cn/home.html
33
api/core/tools/provider/builtin/tianditu/tools/geocoder.py
Normal file
@@ -0,0 +1,33 @@
import json
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GeocoderTool(BuiltinTool):
    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any],
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        base_url = 'http://api.tianditu.gov.cn/geocoder'

        keyword = tool_parameters.get('keyword', '')
        if not keyword:
            return self.create_text_message('Invalid parameter keyword')

        tk = self.runtime.credentials['tianditu_api_key']

        params = {
            'keyWord': keyword,
        }

        result = requests.get(base_url + '?ds=' + json.dumps(params, ensure_ascii=False) + '&tk=' + tk).json()

        return self.create_json_message(result)
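For reference, the geocoder request packs the keyword as JSON into the `ds` query parameter, with the API key in `tk`. A rough standalone equivalent of the call above (the key value is a placeholder):

```python
import json

import requests

tk = "your-tianditu-key"  # placeholder credential
ds = json.dumps({"keyWord": "北京"}, ensure_ascii=False)

# Equivalent to the URL built in GeocoderTool._invoke above
resp = requests.get("http://api.tianditu.gov.cn/geocoder", params={"ds": ds, "tk": tk})
print(resp.json())
```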