Mirror of https://github.com/langgenius/dify.git (synced 2026-01-20 05:54:02 +00:00)

Compare commits: 44 commits

9f58912fd7, 0c746f5c5a, a8cedea15a, 87832ede17, 4d99c689f0, 28b26f67e2, b934232411, 2f120786fd,
6075fee556, de584807e1, a1285cbf15, cf1f6f3961, f4d97ef9fa, 28883e80d4, a0f74cdd9d, 296bf443a8,
af7be9bdd7, 2cfd5568e1, faf40a42bc, 97c972f14d, 3fa5204b0c, 5a756ca981, 01f9feff9f, 2757494265,
b88a8f7bb1, b4225bedb5, a82b4d315a, 3d92784bd4, c06e766d7e, 4a3d15b6de, a798dcfae9, b4a170cb8a,
665318da3d, 66cdf577f5, 891218615e, a938e1f184, 7c7ee633c1, 18af84e193, 025b859c7e, 0e239a4f71,
ca85b0afbe, a0a9461f79, 6a2eb5f442, 0c5892bcb6
.github/workflows/api-model-runtime-tests.yml (vendored): 19 lines changed

@@ -31,28 +31,19 @@ jobs:
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
MOCK_SWITCH: true

steps:
- name: Checkout code
uses: actions/checkout@v2
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Cache pip dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-
cache: 'pip'
cache-dependency-path: ./api/requirements.txt

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -r api/requirements.txt
run: pip install -r ./api/requirements.txt

- name: Run pytest
run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py

.github/workflows/style.yml (vendored, new file): 34 lines changed

@@ -0,0 +1,34 @@
name: Style check

on:
pull_request:
branches:
- main
push:
branches:
- deploy/dev

jobs:
test:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Setup NodeJS
uses: actions/setup-node@v4
with:
node-version: 18
cache: yarn
cache-dependency-path: ./web/package.json

- name: Web dependencies
run: |
cd ./web
yarn install --frozen-lockfile

- name: Web style check
run: |
cd ./web
yarn run lint
@@ -65,6 +65,7 @@ WEAVIATE_BATCH_SIZE=100
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456
QDRANT_CLIENT_TIMEOUT=20

# Milvus configuration
MILVUS_HOST=127.0.0.1

@@ -36,6 +36,7 @@ DEFAULTS = {
'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
'WEAVIATE_GRPC_ENABLED': 'True',
'WEAVIATE_BATCH_SIZE': 100,
'QDRANT_CLIENT_TIMEOUT': 20,
'CELERY_BACKEND': 'database',
'LOG_LEVEL': 'INFO',
'HOSTED_OPENAI_QUOTA_LIMIT': 200,

@@ -87,7 +88,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.4.2"
self.CURRENT_VERSION = "0.4.4"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -141,15 +141,9 @@ class AppListApi(Resource):
model_type=ModelType.LLM
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
model_instance = None

if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
if model_instance:
model_dict = app_model_config.model_dict
model_dict['provider'] = model_instance.provider
model_dict['name'] = model_instance.model

@@ -58,7 +58,7 @@ class ChatMessageAudioApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -78,7 +78,7 @@ class CompletionMessageApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -153,7 +153,7 @@ class ChatMessageApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -38,7 +38,7 @@ class RuleGenerateApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)

return rules

@@ -228,7 +228,7 @@ class MessageMoreLikeThisApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -256,7 +256,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
yield "data: " + json.dumps(
api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

@@ -296,7 +296,7 @@ class MessageSuggestedQuestionApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except Exception:
logging.exception("internal server error.")
raise InternalServerError()

@@ -156,6 +156,9 @@ class DatasetDocumentSegmentApi(Resource):
if not segment:
raise NotFound('Segment not found.')

if segment.status != 'completed':
raise NotFound('Segment is not completed, enable or disable function is not allowed')

document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None:
@@ -54,7 +54,7 @@ class ChatAudioApi(InstalledAppResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -70,7 +70,7 @@ class CompletionApi(InstalledAppResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -134,7 +134,7 @@ class ChatApi(InstalledAppResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -175,7 +175,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

@@ -104,7 +104,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:

@@ -131,7 +131,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

@@ -169,7 +169,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except Exception:
logging.exception("internal server error.")
raise InternalServerError()

@@ -54,7 +54,7 @@ class UniversalChatAudioApi(UniversalChatResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -89,7 +89,7 @@ class UniversalChatApi(UniversalChatResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -126,7 +126,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

@@ -133,7 +133,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
@@ -50,7 +50,7 @@ class AudioApi(AppApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -31,7 +31,7 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')

args = parser.parse_args()

@@ -67,7 +67,7 @@ class CompletionApi(AppApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -96,7 +96,7 @@ class ChatApi(AppApiResource):
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')

@@ -131,7 +131,7 @@ class ChatApi(AppApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -171,7 +171,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

@@ -52,7 +52,7 @@ class AudioApi(WebApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -64,7 +64,7 @@ class CompletionApi(WebApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -124,7 +124,7 @@ class ChatApi(WebApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:

@@ -164,7 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

@@ -138,7 +138,7 @@ class MessageMoreLikeThisApi(WebApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:

@@ -165,7 +165,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:

@@ -202,7 +202,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(str(e))
raise CompletionRequestError(e.description)
except Exception:
logging.exception("internal server error.")
raise InternalServerError()
@@ -75,7 +75,7 @@ class AgentApplicationRunner(AppRunner):
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.originze_prompt_messages(
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,

@@ -153,7 +153,7 @@ class AgentApplicationRunner(AppRunner):
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.originze_prompt_messages(
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,

@@ -1,7 +1,7 @@
import time
from typing import cast, Optional, List, Tuple, Generator, Union

from core.application_queue_manager import ApplicationQueueManager
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory

@@ -50,7 +50,7 @@ class AppRunner:
max_tokens = 0

# get prompt messages without memory and context
prompt_messages, stop = self.originze_prompt_messages(
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=model_config,
prompt_template_entity=prompt_template_entity,

@@ -107,7 +107,7 @@ class AppRunner:
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
model_config.parameters[parameter_rule.name] = max_tokens

def originze_prompt_messages(self, app_record: App,
def organize_prompt_messages(self, app_record: App,
model_config: ModelConfigEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],

@@ -183,7 +183,7 @@ class AppRunner:
index=index,
message=AssistantPromptMessage(content=token)
)
))
), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)

@@ -193,7 +193,8 @@ class AppRunner:
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage()
)
),
pub_from=PublishFrom.APPLICATION_MANAGER
)

def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],

@@ -226,7 +227,8 @@ class AppRunner:
:return:
"""
queue_manager.publish_message_end(
llm_result=invoke_result
llm_result=invoke_result,
pub_from=PublishFrom.APPLICATION_MANAGER
)

def _handle_invoke_result_stream(self, invoke_result: Generator,

@@ -242,7 +244,7 @@ class AppRunner:
text = ''
usage = None
for result in invoke_result:
queue_manager.publish_chunk_message(result)
queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)

text += result.delta.message.content

@@ -263,5 +265,6 @@ class AppRunner:
)

queue_manager.publish_message_end(
llm_result=llm_result
llm_result=llm_result,
pub_from=PublishFrom.APPLICATION_MANAGER
)
@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
from core.application_queue_manager import ApplicationQueueManager
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.features.annotation_reply import AnnotationReplyFeature
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature

@@ -79,7 +79,7 @@ class BasicApplicationRunner(AppRunner):
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.originze_prompt_messages(
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,

@@ -121,7 +121,8 @@ class BasicApplicationRunner(AppRunner):

if annotation_reply:
queue_manager.publish_annotation_reply(
message_annotation_id=annotation_reply.id
message_annotation_id=annotation_reply.id,
pub_from=PublishFrom.APPLICATION_MANAGER
)
self.direct_output(
queue_manager=queue_manager,

@@ -132,16 +133,16 @@ class BasicApplicationRunner(AppRunner):
)
return

# fill in variable inputs from external data tools if exists
external_data_tools = app_orchestration_config.external_data_variables
if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id,
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# fill in variable inputs from external data tools if exists
external_data_tools = app_orchestration_config.external_data_variables
if external_data_tools:
inputs = self.fill_in_inputs_from_external_data_tools(
tenant_id=app_record.tenant_id,
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)

# get context from datasets
context = None

@@ -164,7 +165,7 @@ class BasicApplicationRunner(AppRunner):
# reorganize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
# memory(optional), external data, dataset context(optional)
prompt_messages, stop = self.originze_prompt_messages(
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template,

@@ -7,7 +7,7 @@ from pydantic import BaseModel

from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
from core.entities.application_entities import ApplicationGenerateEntity
from core.application_queue_manager import ApplicationQueueManager
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
AnnotationReplyEvent

@@ -312,8 +312,11 @@ class GenerateTaskPipeline:
index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
)
))
self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
), PublishFrom.TASK_PIPELINE)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
)
continue
else:
self._output_moderation_handler.append_new_token(delta_text)
@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict
from flask import current_app, Flask
from pydantic import BaseModel

from core.application_queue_manager import PublishFrom
from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory

@@ -66,7 +67,7 @@ class OutputModerationHandler(BaseModel):
final_output = result.text

if public_event:
self.on_message_replace_func(final_output)
self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)

return final_output

@@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from core.provider_manager import ProviderManager
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, Conversation, Message, MessageFile, App

@@ -169,15 +169,18 @@ class ApplicationManager:
except ConversationTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e)
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
queue_manager.publish_error(e)
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e)
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
@@ -1,5 +1,6 @@
import queue
import time
from enum import Enum
from typing import Generator, Any

from sqlalchemy.orm import DeclarativeMeta

@@ -13,6 +14,11 @@ from extensions.ext_redis import redis_client
from models.model import MessageAgentThought


class PublishFrom(Enum):
APPLICATION_MANAGER = 1
TASK_PIPELINE = 2


class ApplicationQueueManager:
def __init__(self, task_id: str,
user_id: str,

@@ -61,11 +67,14 @@ class ApplicationQueueManager:
if elapsed_time >= listen_timeout or self._is_stopped():
# publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed
self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
PublishFrom.TASK_PIPELINE
)
self.stop_listen()

if elapsed_time // 10 > last_ping_time:
self.publish(QueuePingEvent())
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
last_ping_time = elapsed_time // 10

def stop_listen(self) -> None:

@@ -75,76 +84,83 @@ class ApplicationQueueManager:
"""
self._q.put(None)

def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
"""
Publish chunk message to channel

:param chunk: chunk
:param pub_from: publish from
:return:
"""
self.publish(QueueMessageEvent(
chunk=chunk
))
), pub_from)

def publish_message_replace(self, text: str) -> None:
def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
"""
Publish message replace
:param text: text
:param pub_from: publish from
:return:
"""
self.publish(QueueMessageReplaceEvent(
text=text
))
), pub_from)

def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
"""
Publish retriever resources
:return:
"""
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)

def publish_annotation_reply(self, message_annotation_id: str) -> None:
def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
"""
Publish annotation reply
:param message_annotation_id: message annotation id
:param pub_from: publish from
:return:
"""
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)

def publish_message_end(self, llm_result: LLMResult) -> None:
def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
"""
Publish message end
:param llm_result: llm result
:param pub_from: publish from
:return:
"""
self.publish(QueueMessageEndEvent(llm_result=llm_result))
self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
self.stop_listen()

def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
"""
Publish agent thought
:param message_agent_thought: message agent thought
:param pub_from: publish from
:return:
"""
self.publish(QueueAgentThoughtEvent(
agent_thought_id=message_agent_thought.id
))
), pub_from)

def publish_error(self, e) -> None:
def publish_error(self, e, pub_from: PublishFrom) -> None:
"""
Publish error
:param e: error
:param pub_from: publish from
:return:
"""
self.publish(QueueErrorEvent(
error=e
))
), pub_from)
self.stop_listen()

def publish(self, event: AppQueueEvent) -> None:
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
self._check_for_sqlalchemy_models(event.dict())

@@ -162,6 +178,9 @@ class ApplicationQueueManager:
if isinstance(event, QueueStopEvent):
self.stop_listen()

if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise ConversationTaskStoppedException()

@classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
"""

@@ -173,7 +192,7 @@ class ApplicationQueueManager:
return

user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
if result != f"{user_prefix}-{user_id}":
if result.decode('utf-8') != f"{user_prefix}-{user_id}":
return

stopped_cache_key = cls._generate_stopped_cache_key(task_id)

@@ -187,7 +206,6 @@ class ApplicationQueueManager:
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
result = redis_client.get(stopped_cache_key)
if result is not None:
redis_client.delete(stopped_cache_key)
return True

return False
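The queue-manager hunks above thread a new PublishFrom value through every publish call, so that publish() knows which side of the pipeline produced an event and can abort application-manager publishes once a stop flag is set. The following is a minimal, self-contained sketch of that pattern, not the real ApplicationQueueManager; the ToyQueueManager class and TaskStoppedException are stand-ins for illustration only.

```python
from enum import Enum
import queue


class PublishFrom(Enum):
    APPLICATION_MANAGER = 1
    TASK_PIPELINE = 2


class TaskStoppedException(Exception):
    """Stand-in for ConversationTaskStoppedException in this sketch."""


class ToyQueueManager:
    """Simplified illustration of the pub_from check added in the diff above."""

    def __init__(self) -> None:
        self._q: queue.Queue = queue.Queue()
        self._stopped = False

    def stop(self) -> None:
        self._stopped = True

    def publish(self, event, pub_from: PublishFrom) -> None:
        self._q.put(event)
        # Only the application-manager side aborts once a stop flag is set;
        # the task pipeline keeps publishing its own control events (pings, stops).
        if pub_from == PublishFrom.APPLICATION_MANAGER and self._stopped:
            raise TaskStoppedException()


manager = ToyQueueManager()
manager.publish("chunk-1", PublishFrom.APPLICATION_MANAGER)
manager.stop()
manager.publish("ping", PublishFrom.TASK_PIPELINE)  # still allowed after stop
try:
    manager.publish("chunk-2", PublishFrom.APPLICATION_MANAGER)
except TaskStoppedException:
    print("generation loop told to stop")
```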
@@ -8,7 +8,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage

from core.application_queue_manager import ApplicationQueueManager
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.entity.agent_loop import AgentLoop
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult

@@ -232,7 +232,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
db.session.add(message_agent_thought)
db.session.commit()

self.queue_manager.publish_agent_thought(message_agent_thought)
self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER)

return message_agent_thought

@@ -2,7 +2,7 @@ from typing import List, Union

from langchain.schema import Document

from core.application_queue_manager import ApplicationQueueManager
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom
from extensions.ext_database import db
from models.dataset import DocumentSegment, DatasetQuery

@@ -80,4 +80,4 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_retriever_resource)
db.session.commit()

self._queue_manager.publish_retriever_resources(resource)
self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER)

@@ -65,7 +65,8 @@ class FileExtractor:
elif file_extension == '.pdf':
loader = PdfLoader(file_path, upload_file=upload_file)
elif file_extension in ['.md', '.markdown']:
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url)
loader = UnstructuredMarkdownLoader(file_path, unstructured_api_url) if is_automatic \
else MarkdownLoader(file_path, autodetect_encoding=True)
elif file_extension in ['.htm', '.html']:
loader = HTMLLoader(file_path)
elif file_extension == '.docx':

@@ -84,7 +85,8 @@ class FileExtractor:
loader = UnstructuredXmlLoader(file_path, unstructured_api_url)
else:
# txt
loader = UnstructuredTextLoader(file_path, unstructured_api_url)
loader = UnstructuredTextLoader(file_path, unstructured_api_url) if is_automatic \
else TextLoader(file_path, autodetect_encoding=True)
else:
if file_extension == '.xlsx':
loader = ExcelLoader(file_path)
@@ -1,5 +1,6 @@
import datetime
import json
import logging
import time
from json import JSONDecodeError
from typing import Optional, List, Dict, Tuple, Iterator

@@ -9,6 +10,7 @@ from pydantic import BaseModel
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
from core.model_runtime.model_providers import model_provider_factory

@@ -18,6 +20,8 @@ from core.model_runtime.utils import encoders
from extensions.ext_database import db
from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider

logger = logging.getLogger(__name__)


class ProviderConfiguration(BaseModel):
"""

@@ -168,6 +172,14 @@ class ProviderConfiguration(BaseModel):
db.session.add(provider_record)
db.session.commit()

provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)

provider_model_credentials_cache.delete()

self.switch_preferred_provider_type(ProviderType.CUSTOM)

def delete_custom_credentials(self) -> None:

@@ -190,6 +202,14 @@ class ProviderConfiguration(BaseModel):
db.session.delete(provider_record)
db.session.commit()

provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)

provider_model_credentials_cache.delete()

def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
-> Optional[dict]:
"""

@@ -311,6 +331,14 @@ class ProviderConfiguration(BaseModel):
db.session.add(provider_model_record)
db.session.commit()

provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)

provider_model_credentials_cache.delete()

def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
"""
Delete custom model credentials.

@@ -332,6 +360,14 @@ class ProviderConfiguration(BaseModel):
db.session.delete(provider_model_record)
db.session.commit()

provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL
)

provider_model_credentials_cache.delete()

def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.

@@ -484,7 +520,13 @@ class ProviderConfiguration(BaseModel):
provider_models.extend(
[
ModelWithProviderEntity(
**m.dict(),
model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)

@@ -533,7 +575,13 @@ class ProviderConfiguration(BaseModel):
for m in models:
provider_models.append(
ModelWithProviderEntity(
**m.dict(),
model=m.model,
label=m.label,
model_type=m.model_type,
features=m.features,
fetch_from=m.fetch_from,
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
)

@@ -544,20 +592,30 @@ class ProviderConfiguration(BaseModel):
if model_configuration.model_type not in model_types:
continue

custom_model_schema = (
provider_instance.get_model_instance(model_configuration.model_type)
.get_customizable_model_schema_from_credentials(
model_configuration.model,
model_configuration.credentials
try:
custom_model_schema = (
provider_instance.get_model_instance(model_configuration.model_type)
.get_customizable_model_schema_from_credentials(
model_configuration.model,
model_configuration.credentials
)
)
)
except Exception as ex:
logger.warning(f'get custom model schema failed, {ex}')
continue

if not custom_model_schema:
continue

provider_models.append(
ModelWithProviderEntity(
**custom_model_schema.dict(),
model=custom_model_schema.model,
label=custom_model_schema.label,
model_type=custom_model_schema.model_type,
features=custom_model_schema.features,
fetch_from=custom_model_schema.fetch_from,
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
)
@@ -61,7 +61,7 @@ class Extensible:

builtin_file_path = os.path.join(subdir_path, '__builtin__')
if os.path.exists(builtin_file_path):
with open(builtin_file_path, 'r') as f:
with open(builtin_file_path, 'r', encoding='utf-8') as f:
position = int(f.read().strip())

if (extension_name + '.py') not in file_names:

@@ -93,7 +93,7 @@ class Extensible:
json_path = os.path.join(subdir_path, 'schema.json')
json_data = {}
if os.path.exists(json_path):
with open(json_path, 'r') as f:
with open(json_path, 'r', encoding='utf-8') as f:
json_data = json.load(f)

extensions[extension_name] = ModuleExtension(

@@ -58,7 +58,7 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension:
raise ValueError("[External data tool] API query failed, variable: {}, "
"error: api_based_extension_id is invalid"
.format(self.config.get('variable')))
.format(self.variable))

# decrypt api_key
api_key = encrypter.decrypt_token(

@@ -74,7 +74,7 @@ class ApiExternalDataTool(ExternalDataTool):
)
except Exception as e:
raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
self.config.get('variable'),
self.variable,
e
))

@@ -87,6 +87,10 @@ class ApiExternalDataTool(ExternalDataTool):

if 'result' not in response_json:
raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
.format(self.config.get('variable')))
.format(self.variable))

if not isinstance(response_json['result'], str):
raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string"
.format(self.variable))

return response_json['result']
@@ -1,35 +0,0 @@
{
"label": {
"en-US": "Weather Search",
"zh-Hans": "天气查询"
},
"form_schema": [
{
"type": "select",
"label": {
"en-US": "Temperature Unit",
"zh-Hans": "温度单位"
},
"variable": "temperature_unit",
"required": true,
"options": [
{
"label": {
"en-US": "Fahrenheit",
"zh-Hans": "华氏度"
},
"value": "fahrenheit"
},
{
"label": {
"en-US": "Centigrade",
"zh-Hans": "摄氏度"
},
"value": "centigrade"
}
],
"default": "centigrade",
"placeholder": "Please select temperature unit"
}
]
}

@@ -1,45 +0,0 @@
from typing import Optional

from core.external_data_tool.base import ExternalDataTool


class WeatherSearch(ExternalDataTool):
"""
The name of custom type must be unique, keep the same with directory and file name.
"""
name: str = "weather_search"

@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
schema.json validation. It will be called when user save the config.

Example:
.. code-block:: python
config = {
"temperature_unit": "centigrade"
}

:param tenant_id: the id of workspace
:param config: the variables of form config
:return:
"""

if not config.get('temperature_unit'):
raise ValueError('temperature unit is required')

def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.

:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
city = inputs.get('city')
temperature_unit = self.config.get('temperature_unit')

if temperature_unit == 'fahrenheit':
return f'Weather in {city} is 32°F'
else:
return f'Weather in {city} is 0°C'
api/core/helper/model_provider_cache.py (new file): 51 lines changed

@@ -0,0 +1,51 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional

from extensions.ext_redis import redis_client


class ProviderCredentialsCacheType(Enum):
PROVIDER = "provider"
MODEL = "provider_model"


class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"

def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.

:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None

return cached_provider_credentials
else:
return None

def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.

:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))

def delete(self) -> None:
"""
Delete cached model provider credentials.

:return:
"""
redis_client.delete(self.cache_key)
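The new helper above is a small Redis-backed cache for provider and model credentials, keyed by tenant and record id, and the provider_configuration hunks invalidate it whenever custom credentials change. A hedged usage sketch follows; the tenant and record ids are placeholders, and it assumes the Dify api package plus a reachable Redis, just as the file itself does.

```python
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType

# Placeholder ids for illustration; real callers pass the tenant id and the
# Provider / ProviderModel database record id.
cache = ProviderCredentialsCache(
    tenant_id="tenant-123",
    identity_id="provider-record-456",
    cache_type=ProviderCredentialsCacheType.PROVIDER,
)

credentials = cache.get()  # returns None on a miss or if the cached JSON cannot be parsed
if credentials is None:
    credentials = {"openai_api_key": "decrypted-key"}  # assumed to come from the database
    cache.set(credentials)  # stored with an 86400-second (one day) TTL via SETEX

# The surrounding diff calls delete() after credentials are added, updated or removed.
cache.delete()
```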
@@ -59,7 +59,7 @@ class IndexingRunner:
first()

# load file
text_docs = self._load_data(dataset_document)
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')

# get splitter
splitter = self._get_splitter(processing_rule)

@@ -113,15 +113,14 @@ class IndexingRunner:
for document_segment in document_segments:
db.session.delete(document_segment)
db.session.commit()

# load file
text_docs = self._load_data(dataset_document)

# get the process rule
processing_rule = db.session.query(DatasetProcessRule). \
filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
first()

# load file
text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')

# get splitter
splitter = self._get_splitter(processing_rule)

@@ -238,14 +237,15 @@ class IndexingRunner:
preview_texts = []
total_segments = 0
for file_detail in file_details:
# load data from file
text_docs = FileExtractor.load(file_detail)

processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"],
rules=json.dumps(tmp_processing_rule["rules"])
)

# load data from file
text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')

# get splitter
splitter = self._get_splitter(processing_rule)

@@ -382,13 +382,15 @@ class IndexingRunner:
)
total_segments += len(documents)

embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
embedding_model_type_instance = None
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model_instance:
if indexing_technique == 'high_quality' and embedding_model_type_instance:
tokens += embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,

@@ -457,7 +459,7 @@ class IndexingRunner:
one_or_none()

if file_detail:
text_docs = FileExtractor.load(file_detail, is_automatic=True)
text_docs = FileExtractor.load(file_detail, is_automatic=automatic)
elif dataset_document.data_source_type == 'notion_import':
loader = NotionLoader.from_document(dataset_document)
text_docs = loader.load()
@@ -8,6 +8,9 @@ class InvokeError(Exception):
|
||||
def __init__(self, description: Optional[str] = None) -> None:
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
return self.description or self.__class__.__name__
|
||||
|
||||
|
||||
class InvokeConnectionError(InvokeError):
|
||||
"""Raised when the Invoke returns connection error."""
|
||||
|
||||
@@ -147,13 +147,15 @@ class AIModel(ABC):
|
||||
# read _position.yaml file
|
||||
position_map = {}
|
||||
if os.path.exists(position_file_path):
|
||||
with open(position_file_path, 'r') as f:
|
||||
position_map = yaml.safe_load(f)
|
||||
with open(position_file_path, 'r', encoding='utf-8') as f:
|
||||
positions = yaml.safe_load(f)
|
||||
# convert list to dict with key as model provider name, value as index
|
||||
position_map = {position: index for index, position in enumerate(positions)}
|
||||
|
||||
# traverse all model_schema_yaml_paths
|
||||
for model_schema_yaml_path in model_schema_yaml_paths:
|
||||
# read yaml data from yaml file
|
||||
with open(model_schema_yaml_path, 'r') as f:
|
||||
with open(model_schema_yaml_path, 'r', encoding='utf-8') as f:
|
||||
yaml_data = yaml.safe_load(f)
|
||||
|
||||
new_parameter_rules = []
|
||||
@@ -236,16 +238,6 @@ class AIModel(ABC):
|
||||
:param credentials: model credentials
|
||||
:return: model schema
|
||||
"""
|
||||
if 'schema' in credentials:
|
||||
schema_dict = json.loads(credentials['schema'])
|
||||
|
||||
try:
|
||||
model_instance = AIModelEntity.parse_obj(schema_dict)
|
||||
return model_instance
|
||||
except ValidationError as e:
|
||||
logging.exception(f"Invalid model schema for {model}")
|
||||
return self._get_customizable_model_schema(model, credentials)
|
||||
|
||||
return self._get_customizable_model_schema(model, credentials)
|
||||
|
||||
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
|
||||
@@ -165,7 +165,7 @@ class LargeLanguageModel(AIModel):
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage,
usage=usage if usage else LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint
),
credentials=credentials,

@@ -47,7 +47,7 @@ class ModelProvider(ABC):
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = {}
if os.path.exists(yaml_path):
with open(yaml_path, 'r') as f:
with open(yaml_path, 'r', encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)

try:

@@ -112,7 +112,7 @@ class ModelProvider(ABC):
model_class = None
for name, obj in vars(mod).items():
if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
and obj != AIModel):
and obj != AIModel and obj.__module__ == mod.__name__):
model_class = obj
break
@@ -1,19 +1,20 @@
openai: 0
anthropic: 1
azure_openai: 2
google: 3
replicate: 4
huggingface_hub: 5
cohere: 6
zhipuai: 7
baichuan: 8
spark: 9
minimax: 10
tongyi: 11
wenxin: 12
jina: 13
chatglm: 14
xinference: 15
openllm: 16
localai: 17
openai_api_compatible: 18
- openai
- anthropic
- azure_openai
- google
- replicate
- huggingface_hub
- cohere
- togetherai
- zhipuai
- baichuan
- spark
- minimax
- tongyi
- wenxin
- jina
- chatglm
- xinference
- openllm
- localai
- openai_api_compatible
@@ -309,7 +309,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
model=response.model,
|
||||
model=response.model or model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
|
||||
@@ -54,7 +54,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
_iter = range(0, len(tokens), max_chunks)

for i in _iter:
embeddings, embedding_used_tokens = self._embedding_invoke(
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts=tokens[i: i + max_chunks],
@@ -62,7 +62,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
)

used_tokens += embedding_used_tokens
batched_embeddings += [data for data in embeddings]
batched_embeddings += embeddings_batch

results: list[list[list[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
@@ -73,7 +73,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
for i in range(len(texts)):
_result = results[i]
if len(_result) == 0:
embeddings, embedding_used_tokens = self._embedding_invoke(
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts=[""],
@@ -81,7 +81,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
)

used_tokens += embedding_used_tokens
average = embeddings[0]
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
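The rename to embeddings_batch avoids shadowing the final embeddings list. The surrounding logic embeds a long input chunk by chunk and then combines the per-chunk vectors; a small sketch of that combination step, with illustrative function and argument names:

    import numpy as np

    def combine_chunk_embeddings(chunk_vectors: list[list[float]], chunk_token_counts: list[int]) -> list[float]:
        # token-count-weighted average of the per-chunk vectors
        average = np.average(chunk_vectors, axis=0, weights=chunk_token_counts)
        # re-normalize so the combined vector has unit length, matching the model output
        return (average / np.linalg.norm(average)).tolist()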
@@ -1,7 +1,7 @@
|
||||
from typing import Generator, List, Optional, Union, cast
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
|
||||
@@ -156,9 +156,9 @@ class LocalAILarguageModel(LargeLanguageModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
completion_model = None
if credentials['completion_type'] == 'chat_completion':
completion_model = LLMMode.CHAT
completion_model = LLMMode.CHAT.value
elif credentials['completion_type'] == 'completion':
completion_model = LLMMode.COMPLETION
completion_model = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")

@@ -202,7 +202,7 @@ class LocalAILarguageModel(LargeLanguageModel):
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={ 'mode': completion_model } if completion_model else {},
model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
parameter_rules=rules
)
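As in the LocalAI change above, model_properties is now keyed by ModelPropertyKey and stores the enum's string value rather than the enum member, so customizable model schemas serialize the same way as built-in ones. A short sketch of the pattern, using the imports shown in the hunks:

    from core.model_runtime.entities.llm_entities import LLMMode
    from core.model_runtime.entities.model_entities import ModelPropertyKey

    # 'chat' / 'completion' as plain strings, not LLMMode members
    completion_model = LLMMode.CHAT.value
    model_properties = {ModelPropertyKey.MODE: completion_model} if completion_model else {}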
@@ -30,6 +30,10 @@ class ModelProviderExtension(BaseModel):
|
||||
class ModelProviderFactory:
|
||||
model_provider_extensions: dict[str, ModelProviderExtension] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
# for cache in memory
|
||||
self.get_providers()
|
||||
|
||||
def get_providers(self) -> list[ProviderEntity]:
|
||||
"""
|
||||
Get all providers
|
||||
@@ -212,8 +216,10 @@ class ModelProviderFactory:
|
||||
# read _position.yaml file
|
||||
position_map = {}
|
||||
if os.path.exists(position_file_path):
|
||||
with open(position_file_path, 'r') as f:
|
||||
position_map = yaml.safe_load(f)
|
||||
with open(position_file_path, 'r', encoding='utf-8') as f:
|
||||
positions = yaml.safe_load(f)
|
||||
# convert list to dict with key as model provider name, value as index
|
||||
position_map = {position: index for index, position in enumerate(positions)}
|
||||
|
||||
# traverse all model_provider_dir_paths
|
||||
for model_provider_dir_path in model_provider_dir_paths:
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
gpt-4: 0
|
||||
gpt-4-32k: 1
|
||||
gpt-4-1106-preview: 2
|
||||
gpt-4-vision-preview: 3
|
||||
gpt-3.5-turbo: 4
|
||||
gpt-3.5-turbo-16k: 5
|
||||
gpt-3.5-turbo-1106: 6
|
||||
gpt-3.5-turbo-instruct: 7
|
||||
text-davinci-003: 8
|
||||
- gpt-4
|
||||
- gpt-4-32k
|
||||
- gpt-4-1106-preview
|
||||
- gpt-4-vision-preview
|
||||
- gpt-3.5-turbo
|
||||
- gpt-3.5-turbo-16k
|
||||
- gpt-3.5-turbo-16k-0613
|
||||
- gpt-3.5-turbo-1106
|
||||
- gpt-3.5-turbo-0613
|
||||
- gpt-3.5-turbo-instruct
|
||||
- text-davinci-003
|
||||
@@ -40,87 +40,4 @@ class _CommonOAI_API_Compat:
|
||||
requests.exceptions.ConnectTimeout, # Timeout
|
||||
requests.exceptions.ReadTimeout # Timeout
|
||||
]
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
model_type = ModelType.LLM if credentials.get('__model_type') == 'llm' else ModelType.TEXT_EMBEDDING
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=model_type,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size', 16000),
|
||||
ModelPropertyKey.MAX_CHUNKS: credentials.get('max_chunks', 1),
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TEMPERATURE.value,
|
||||
label=I18nObject(en_US="Temperature"),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get('temperature', 1)),
|
||||
min=0,
|
||||
max=2
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TOP_P.value,
|
||||
label=I18nObject(en_US="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get('top_p', 1)),
|
||||
min=0,
|
||||
max=1
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_k",
|
||||
label=I18nObject(en_US="Top K"),
|
||||
type=ParameterType.INT,
|
||||
default=int(credentials.get('top_k', 1)),
|
||||
min=1,
|
||||
max=100
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.FREQUENCY_PENALTY.value,
|
||||
label=I18nObject(en_US="Frequency Penalty"),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get('frequency_penalty', 0)),
|
||||
min=-2,
|
||||
max=2
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.PRESENCE_PENALTY.value,
|
||||
label=I18nObject(en_US="PRESENCE Penalty"),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get('PRESENCE_penalty', 0)),
|
||||
min=-2,
|
||||
max=2
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.MAX_TOKENS.value,
|
||||
label=I18nObject(en_US="Max Tokens"),
|
||||
type=ParameterType.INT,
|
||||
default=1024,
|
||||
min=1,
|
||||
max=int(credentials.get('max_tokens_to_sample', 4096)),
|
||||
)
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get('input_price', 0)),
|
||||
output=Decimal(credentials.get('output_price', 0)),
|
||||
unit=Decimal(credentials.get('unit', 0)),
|
||||
currency=credentials.get('currency', "USD")
|
||||
)
|
||||
)
|
||||
|
||||
if model_type == ModelType.LLM:
|
||||
if credentials['mode'] == 'chat':
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT
|
||||
elif credentials['mode'] == 'completion':
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION
|
||||
else:
|
||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
||||
|
||||
return entity
|
||||
}
|
||||
@@ -158,7 +158,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
|
||||
ModelPropertyKey.MODE: credentials.get('mode'),
|
||||
},
|
||||
parameter_rules=[
|
||||
@@ -196,9 +196,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.PRESENCE_PENALTY.value,
|
||||
label=I18nObject(en_US="PRESENCE Penalty"),
|
||||
label=I18nObject(en_US="Presence Penalty"),
|
||||
type=ParameterType.FLOAT,
|
||||
default=float(credentials.get('PRESENCE_penalty', 0)),
|
||||
default=float(credentials.get('presence_penalty', 0)),
|
||||
min=-2,
|
||||
max=2
|
||||
),
|
||||
@@ -219,6 +219,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
)
|
||||
)
|
||||
|
||||
if credentials['mode'] == 'chat':
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
|
||||
elif credentials['mode'] == 'completion':
|
||||
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
|
||||
|
||||
return entity
|
||||
|
||||
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
|
||||
@@ -261,7 +268,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
if completion_type is LLMMode.CHAT:
|
||||
endpoint_url = urljoin(endpoint_url, 'chat/completions')
|
||||
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
|
||||
elif completion_type == LLMMode.COMPLETION:
|
||||
elif completion_type is LLMMode.COMPLETION:
|
||||
endpoint_url = urljoin(endpoint_url, 'completions')
|
||||
data['prompt'] = prompt_messages[0].content
|
||||
else:
|
||||
@@ -291,10 +298,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
stream=stream
|
||||
)
|
||||
|
||||
# Debug: Print request headers and json data
|
||||
logger.debug(f"Request headers: {headers}")
|
||||
logger.debug(f"Request JSON data: {data}")
|
||||
|
||||
if response.status_code != 200:
|
||||
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
|
||||
|
||||
@@ -337,9 +340,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
)
)

for chunk in response.iter_content(chunk_size=2048):
for chunk in response.iter_lines(decode_unicode=True, delimiter='\n\n'):
if chunk:
decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()

chunk_json = None
try:
@@ -356,7 +359,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
continue

choice = chunk_json['choices'][0]
chunk_index = choice['index'] if 'index' in choice else chunk_index
chunk_index += 1

if 'delta' in choice:
delta = choice['delta']
@@ -408,12 +411,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message=assistant_prompt_message,
)
)
else:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="End of stream."
)

chunk_index += 1
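The streaming fix above replaces fixed-size iter_content chunks, which can split a server-sent event in the middle, with iter_lines on a blank-line delimiter so each iteration yields one complete data: event. A sketch of that pattern, with a hypothetical endpoint and model name:

    import json
    import requests

    response = requests.post(
        "http://localhost:8000/v1/chat/completions",   # assumed endpoint
        json={"model": "my-model", "stream": True},    # assumed payload
        stream=True,
    )

    # every SSE event ends with a blank line, so '\n\n' yields whole events
    for event in response.iter_lines(decode_unicode=True, delimiter='\n\n'):
        if not event:
            continue
        payload = event.strip().lstrip('data: ').lstrip()
        if payload == '[DONE]':
            break
        chunk = json.loads(payload)
        # use chunk['choices'][0] as the loop above does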
@@ -2,8 +2,8 @@ provider: openai_api_compatible
|
||||
label:
|
||||
en_US: OpenAI-API-compatible
|
||||
description:
|
||||
en_US: All model providers compatible with OpenAI's API standard, such as Together.ai.
|
||||
zh_Hans: 兼容 OpenAI API 的模型供应商,例如 Together.ai。
|
||||
en_US: Model providers compatible with OpenAI's API standard, such as LM Studio.
|
||||
zh_Hans: 兼容 OpenAI API 的模型供应商,例如 LM Studio 。
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
|
||||
@@ -112,7 +112,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
||||
credentials=credentials,
|
||||
tokens=used_tokens
|
||||
)
|
||||
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=batched_embeddings,
|
||||
usage=usage,
|
||||
|
||||
@@ -6,7 +6,7 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate import Open
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType, FetchFrom, ModelType, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
|
||||
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
|
||||
@@ -198,7 +198,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties={
|
||||
'mode': LLMMode.COMPLETION,
|
||||
ModelPropertyKey.MODE: LLMMode.COMPLETION.value,
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \
|
||||
PromptMessageRole, UserPromptMessage, SystemPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.entities.model_entities import ParameterRule, AIModelEntity, FetchFrom, ModelType, ModelPropertyKey
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.replicate._common import _CommonReplicate
|
||||
@@ -91,7 +91,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties={
|
||||
'mode': model_type.value
|
||||
ModelPropertyKey.MODE: model_type.value
|
||||
},
|
||||
parameter_rules=self._get_customizable_model_parameter_rules(model, credentials)
|
||||
)
|
||||
|
||||
@@ -19,13 +19,23 @@ class SparkProvider(ModelProvider):
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `claude-instant-1` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='spark-1.5',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(
|
||||
model='spark-3',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
<svg width="114" height="24" viewBox="0 0 114 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M3.21688 7.55431H1V5.74708H3.21688V2.30127H5.19279V5.74708H8.30124V7.55431H5.19279V14.8074C5.19279 15.3214 5.28918 15.6909 5.48195 15.9158C5.69079 16.1246 6.0442 16.2291 6.5422 16.2291H8.68679V18.0363H6.42171C5.26507 18.0363 4.43776 17.7792 3.93977 17.2652C3.45784 16.7511 3.21688 15.9398 3.21688 14.8314V7.55431Z" fill="black"/>
|
||||
<path d="M15.0554 18.1809C13.8667 18.1809 12.8064 17.9159 11.8747 17.3857C10.959 16.8556 10.2441 16.1166 9.73006 15.1689C9.21601 14.2211 8.95898 13.1287 8.95898 11.8918C8.95898 10.6548 9.21601 9.5624 9.73006 8.6146C10.2441 7.6668 10.959 6.92785 11.8747 6.39772C12.8064 5.8676 13.8667 5.60254 15.0554 5.60254C16.2442 5.60254 17.2964 5.8676 18.212 6.39772C19.1438 6.92785 19.8667 7.6668 20.3807 8.6146C20.8948 9.5624 21.1518 10.6548 21.1518 11.8918C21.1518 13.1287 20.8948 14.2211 20.3807 15.1689C19.8667 16.1166 19.1438 16.8556 18.212 17.3857C17.2964 17.9159 16.2442 18.1809 15.0554 18.1809ZM15.0554 16.4219C15.8586 16.4219 16.5654 16.2291 17.1759 15.8435C17.8023 15.458 18.2844 14.9199 18.6216 14.2291C18.959 13.5383 19.1277 12.7592 19.1277 11.8918C19.1277 11.0242 18.959 10.2451 18.6216 9.55437C18.2844 8.86359 17.8023 8.32545 17.1759 7.9399C16.5654 7.55436 15.8586 7.36159 15.0554 7.36159C14.2521 7.36159 13.5373 7.55436 12.9108 7.9399C12.3004 8.32545 11.8265 8.86359 11.4891 9.55437C11.1518 10.2451 10.9831 11.0242 10.9831 11.8918C10.9831 12.7592 11.1518 13.5383 11.4891 14.2291C11.8265 14.9199 12.3004 15.458 12.9108 15.8435C13.5373 16.2291 14.2521 16.4219 15.0554 16.4219Z" fill="black"/>
|
||||
<path d="M34.6823 5.74712V17.4339C34.6823 21.1448 32.6503 23.0002 28.5859 23.0002C26.9956 23.0002 25.6944 22.6388 24.6823 21.9158C23.6863 21.193 23.108 20.1649 22.9474 18.8315H24.9715C25.1322 19.6025 25.5418 20.197 26.2004 20.6146C26.8591 21.0323 27.7024 21.2411 28.7305 21.2411C31.3811 21.2411 32.7065 19.948 32.7065 17.3617V15.9159C31.823 17.4259 30.4173 18.1809 28.4896 18.1809C27.349 18.1809 26.3289 17.9319 25.4293 17.4339C24.5458 16.9359 23.847 16.213 23.3329 15.2652C22.8349 14.3174 22.5859 13.193 22.5859 11.8918C22.5859 10.6548 22.8349 9.5624 23.3329 8.6146C23.847 7.6668 24.5538 6.92785 25.4534 6.39772C26.3531 5.8676 27.365 5.60254 28.4896 5.60254C29.4855 5.60254 30.337 5.80334 31.0438 6.20495C31.7507 6.5905 32.3049 7.14472 32.7065 7.86761L32.9715 5.74712H34.6823ZM28.6824 16.4219C29.4695 16.4219 30.1683 16.2371 30.7787 15.8677C31.4053 15.4821 31.8872 14.9519 32.2246 14.2772C32.5618 13.5865 32.7306 12.8074 32.7306 11.9399C32.7306 11.0564 32.5618 10.2692 32.2246 9.57846C31.8872 8.87163 31.4053 8.32545 30.7787 7.9399C30.1683 7.55436 29.4695 7.36159 28.6824 7.36159C27.4615 7.36159 26.4735 7.78729 25.7185 8.63869C24.9795 9.47404 24.61 10.5584 24.61 11.8918C24.61 13.2251 24.9795 14.3174 25.7185 15.1689C26.4735 16.0042 27.4615 16.4219 28.6824 16.4219Z" fill="black"/>
|
||||
<path d="M36.5449 11.8918C36.5449 10.6387 36.7859 9.5383 37.2678 8.5905C37.7658 7.64271 38.4565 6.91179 39.3401 6.39772C40.2236 5.8676 41.2357 5.60254 42.3763 5.60254C43.5007 5.60254 44.4968 5.83547 45.3642 6.30133C46.2317 6.7672 46.9144 7.4419 47.4124 8.32545C47.9104 9.20898 48.1755 10.2451 48.2076 11.4339C48.2076 11.6106 48.1915 11.8918 48.1594 12.2772H38.6172V12.446C38.6493 13.6507 39.0187 14.6146 39.7256 15.3375C40.4324 16.0605 41.3562 16.4219 42.4967 16.4219C43.3802 16.4219 44.1272 16.205 44.7377 15.7712C45.3642 15.3215 45.7818 14.703 45.9908 13.9158H47.9907C47.7497 15.1689 47.1473 16.197 46.1834 17.0001C45.2196 17.7873 44.0389 18.1809 42.6412 18.1809C41.4204 18.1809 40.3521 17.9239 39.4365 17.4098C38.5208 16.8797 37.806 16.1408 37.2919 15.1929C36.7939 14.2291 36.5449 13.1287 36.5449 11.8918ZM46.1594 10.6387C46.063 9.59452 45.6694 8.78328 44.9787 8.20496C44.304 7.62664 43.4445 7.33749 42.4003 7.33749C41.4686 7.33749 40.6493 7.64271 39.9425 8.25315C39.2357 8.86359 38.8341 9.65878 38.7376 10.6387H46.1594Z" fill="black"/>
|
||||
<path d="M50.7442 7.55431H48.5273V5.74708H50.7442V2.30127H52.7201V5.74708H55.8285V7.55431H52.7201V14.8074C52.7201 15.3214 52.8165 15.6909 53.0093 15.9158C53.2181 16.1246 53.5715 16.2291 54.0696 16.2291H56.2141V18.0363H53.9491C52.7924 18.0363 51.9651 17.7792 51.4671 17.2652C50.9851 16.7511 50.7442 15.9398 50.7442 14.8314V7.55431Z" fill="black"/>
|
||||
<path d="M63.2468 5.6027C64.7408 5.6027 65.9456 6.0525 66.8613 6.95211C67.7769 7.8517 68.2348 9.26536 68.2348 11.1931V18.0365H66.2589V11.3136C66.2589 10.0445 65.9697 9.08062 65.3914 8.42199C64.8131 7.74729 63.9858 7.40994 62.9095 7.40994C61.7689 7.40994 60.8613 7.81154 60.1866 8.61476C59.5279 9.41798 59.1986 10.5103 59.1986 11.8919V18.0365H57.2227V1.16895H59.1986V7.77139C59.6002 7.12881 60.1303 6.60672 60.789 6.20511C61.4637 5.8035 62.283 5.6027 63.2468 5.6027Z" fill="black"/>
|
||||
<path d="M69.9258 11.8918C69.9258 10.6387 70.1667 9.5383 70.6486 8.5905C71.1467 7.64271 71.8374 6.91179 72.721 6.39772C73.6045 5.8676 74.6165 5.60254 75.7571 5.60254C76.8816 5.60254 77.8776 5.83547 78.7451 6.30133C79.6126 6.7672 80.2953 7.4419 80.7933 8.32545C81.2912 9.20898 81.5563 10.2451 81.5885 11.4339C81.5885 11.6106 81.5723 11.8918 81.5403 12.2772H71.998V12.446C72.0302 13.6507 72.3996 14.6146 73.1064 15.3375C73.8133 16.0605 74.737 16.4219 75.8776 16.4219C76.7611 16.4219 77.5081 16.205 78.1186 15.7712C78.7451 15.3215 79.1627 14.703 79.3715 13.9158H81.3715C81.1306 15.1689 80.5282 16.197 79.5643 17.0001C78.6005 17.7873 77.4198 18.1809 76.0221 18.1809C74.8012 18.1809 73.733 17.9239 72.8173 17.4098C71.9017 16.8797 71.1868 16.1408 70.6728 15.1929C70.1747 14.2291 69.9258 13.1287 69.9258 11.8918ZM79.5403 10.6387C79.4438 9.59452 79.0502 8.78328 78.3595 8.20496C77.6848 7.62664 76.8254 7.33749 75.7811 7.33749C74.8495 7.33749 74.0302 7.64271 73.3234 8.25315C72.6165 8.86359 72.2149 9.65878 72.1185 10.6387H79.5403Z" fill="black"/>
|
||||
<path d="M89.6864 5.74707V7.67478H88.6984C87.5257 7.67478 86.6823 8.06836 86.1682 8.85551C85.6703 9.64266 85.4212 10.6146 85.4212 11.7712V18.0363H83.4453V5.74707H85.1562L85.4212 7.6025C85.7746 7.04024 86.2325 6.59045 86.7947 6.25309C87.357 5.91575 88.1361 5.74707 89.1321 5.74707H89.6864Z" fill="black"/>
|
||||
<path d="M109.812 16.2291V18.0364H108.726C107.939 18.0364 107.378 17.8757 107.04 17.5543C106.703 17.2331 106.526 16.7592 106.51 16.1327C105.562 17.4982 104.189 18.1809 102.39 18.1809C101.024 18.1809 99.9237 17.8596 99.0883 17.2171C98.269 16.5745 97.8594 15.6989 97.8594 14.5905C97.8594 13.3536 98.2771 12.4058 99.1124 11.7471C99.9637 11.0885 101.193 10.7592 102.799 10.7592H106.414V9.9158C106.414 9.11259 106.14 8.48608 105.594 8.03628C105.064 7.58648 104.317 7.36159 103.353 7.36159C102.502 7.36159 101.795 7.55436 101.233 7.9399C100.687 8.30937 100.349 8.80737 100.221 9.43388H98.2449C98.3894 8.22906 98.9196 7.28929 99.8353 6.61459C100.767 5.93989 101.972 5.60254 103.45 5.60254C105.024 5.60254 106.237 5.98808 107.088 6.75917C107.955 7.5142 108.39 8.60657 108.39 10.0363V15.3375C108.39 15.9319 108.662 16.2291 109.209 16.2291H109.812ZM106.414 12.4218H102.606C100.775 12.4218 99.8594 13.1045 99.8594 14.47C99.8594 15.0805 100.1 15.5704 100.582 15.9399C101.064 16.3094 101.715 16.4942 102.534 16.4942C103.739 16.4942 104.687 16.1809 105.377 15.5544C106.068 14.9118 106.414 14.0684 106.414 13.0242V12.4218Z" fill="black"/>
|
||||
<path d="M111.922 1C112.291 1 112.597 1.12048 112.837 1.36145C113.079 1.60241 113.199 1.90763 113.199 2.27711C113.199 2.64659 113.079 2.95182 112.837 3.19278C112.597 3.43374 112.291 3.55423 111.922 3.55423C111.552 3.55423 111.247 3.43374 111.007 3.19278C110.765 2.95182 110.645 2.64659 110.645 2.27711C110.645 1.90763 110.765 1.60241 111.007 1.36145C111.247 1.12048 111.552 1 111.922 1ZM110.934 5.74701H112.91V18.0362H110.934V5.74701Z" fill="black"/>
|
||||
<path d="M93.9949 16.1652C93.9949 17.1986 93.1469 18.0364 92.1009 18.0364C91.055 18.0364 90.207 17.1986 90.207 16.1652C90.207 15.1317 91.055 14.2939 92.1009 14.2939C93.1469 14.2939 93.9949 15.1317 93.9949 16.1652Z" fill="#0F6FFF"/>
|
||||
</svg>
@@ -0,0 +1,19 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g clip-path="url(#clip0_15960_46917)">
|
||||
<mask id="mask0_15960_46917" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="0" width="16" height="16">
|
||||
<path d="M16 0H0V16H16V0Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask0_15960_46917)">
|
||||
<path d="M13.1765 0H2.82353C1.26414 0 0 1.26414 0 2.82353V13.1765C0 14.7359 1.26414 16 2.82353 16H13.1765C14.7359 16 16 14.7359 16 13.1765V2.82353C16 1.26414 14.7359 0 13.1765 0Z" fill="#F1EFED"/>
|
||||
<path d="M11.4119 7.64706C12.9713 7.64706 14.2354 6.38292 14.2354 4.82353C14.2354 3.26414 12.9713 2 11.4119 2C9.85252 2 8.58838 3.26414 8.58838 4.82353C8.58838 6.38292 9.85252 7.64706 11.4119 7.64706Z" fill="#D3D1D1"/>
|
||||
<path d="M11.4119 14.2354C12.9713 14.2354 14.2354 12.9713 14.2354 11.4119C14.2354 9.85252 12.9713 8.58838 11.4119 8.58838C9.85252 8.58838 8.58838 9.85252 8.58838 11.4119C8.58838 12.9713 9.85252 14.2354 11.4119 14.2354Z" fill="#D3D1D1"/>
|
||||
<path d="M4.82353 14.2354C6.38292 14.2354 7.64706 12.9713 7.64706 11.4119C7.64706 9.85252 6.38292 8.58838 4.82353 8.58838C3.26414 8.58838 2 9.85252 2 11.4119C2 12.9713 3.26414 14.2354 4.82353 14.2354Z" fill="#D3D1D1"/>
|
||||
<path d="M4.82353 7.64706C6.38292 7.64706 7.64706 6.38292 7.64706 4.82353C7.64706 3.26414 6.38292 2 4.82353 2C3.26414 2 2 3.26414 2 4.82353C2 6.38292 3.26414 7.64706 4.82353 7.64706Z" fill="#0F6FFF"/>
|
||||
</g>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_15960_46917">
|
||||
<rect width="16" height="16" fill="white"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
api/core/model_runtime/model_providers/togetherai/llm/llm.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from typing import Generator, List, Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel

class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):

def _update_endpoint_url(self, credentials: dict):
credentials['endpoint_url'] = "https://api.together.xyz/v1"
return credentials

def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)

return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)

def validate_credentials(self, model: str, credentials: dict) -> None:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)

return super().validate_credentials(model, cred_with_endpoint)

def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)

return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)

return super().get_customizable_model_schema(model, cred_with_endpoint)

def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
cred_with_endpoint = self._update_endpoint_url(credentials=credentials)

return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)
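together.ai is wired up as a thin subclass of the OpenAI-compatible implementation: it only pins endpoint_url and delegates every call to the base class. Any other OpenAI-compatible host could be added the same way; a sketch with a hypothetical provider class and URL:

    from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel

    class MyHostedLargeLanguageModel(OAIAPICompatLargeLanguageModel):
        """Hypothetical provider: pin the endpoint, reuse the OpenAI-compatible base."""

        def _update_endpoint_url(self, credentials: dict) -> dict:
            credentials['endpoint_url'] = "https://llm.example.com/v1"  # assumed URL
            return credentials

        def validate_credentials(self, model: str, credentials: dict) -> None:
            return super().validate_credentials(model, self._update_endpoint_url(credentials))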
@@ -0,0 +1,13 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TogetherAIProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
||||
@@ -0,0 +1,75 @@
|
||||
provider: togetherai
|
||||
label:
|
||||
en_US: together.ai
|
||||
icon_small:
|
||||
en_US: togetherai_square.svg
|
||||
icon_large:
|
||||
en_US: togetherai.svg
|
||||
background: "#F1EFED"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API key from together.ai
|
||||
zh_Hans: 从 together.ai 获取 API Key
|
||||
url:
|
||||
en_US: https://api.together.xyz/
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter full model name
|
||||
zh_Hans: 输入模型全称
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
required: true
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: mode
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Completion mode
|
||||
type: select
|
||||
required: false
|
||||
default: chat
|
||||
placeholder:
|
||||
zh_Hans: 选择对话类型
|
||||
en_US: Select completion mode
|
||||
options:
|
||||
- value: completion
|
||||
label:
|
||||
en_US: Completion
|
||||
zh_Hans: 补全
|
||||
- value: chat
|
||||
label:
|
||||
en_US: Chat
|
||||
zh_Hans: 对话
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 模型上下文长度
|
||||
en_US: Model context size
|
||||
required: true
|
||||
type: text-input
|
||||
default: '4096'
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
- variable: max_tokens_to_sample
|
||||
label:
|
||||
zh_Hans: 最大 token 上限
|
||||
en_US: Upper bound for max tokens
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
default: '4096'
|
||||
type: text-input
|
||||
@@ -52,9 +52,13 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
response = dashscope.Tokenization.call(
|
||||
model=model,
|
||||
prompt=self._convert_messages_to_prompt(prompt_messages),
|
||||
**credentials_kwargs
|
||||
)
|
||||
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
@@ -108,10 +112,6 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
dashscope.api_key = credentials_kwargs['api_key']
|
||||
|
||||
print(credentials_kwargs, 'credentials_kwargs')
|
||||
|
||||
client = EnhanceTongyi(
|
||||
model_name=model,
|
||||
streaming=stream,
|
||||
@@ -121,7 +121,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||
params = {
|
||||
'model': model,
|
||||
'prompt': self._convert_messages_to_prompt(prompt_messages),
|
||||
**model_parameters
|
||||
**model_parameters,
|
||||
**credentials_kwargs
|
||||
}
|
||||
if stream:
|
||||
responses = stream_generate_with_retry(
|
||||
@@ -222,7 +223,6 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
print(credentials, 'credentials')
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials['dashscope_api_key'],
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, SystemPromptMessage, AssistantPromptMessage
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType
|
||||
from core.model_runtime.entities.model_entities import FetchFrom, ModelType, ParameterRule, ParameterType, ModelPropertyKey
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.xinference.llm.xinference_helper import XinferenceHelper, XinferenceModelExtraParameter
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
|
||||
@@ -56,10 +56,18 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
XinferenceHelper.get_xinference_extra_parameter(
|
||||
extra_param = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid']
|
||||
)
|
||||
if 'completion_type' not in credentials:
|
||||
if 'chat' in extra_param.model_ability:
|
||||
credentials['completion_type'] = 'chat'
|
||||
elif 'generate' in extra_param.model_ability:
|
||||
credentials['completion_type'] = 'completion'
|
||||
else:
|
||||
raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')
|
||||
|
||||
except RuntimeError as e:
|
||||
raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
|
||||
except KeyError as e:
|
||||
@@ -256,17 +264,26 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
]
|
||||
|
||||
completion_type = None
|
||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid']
|
||||
)
|
||||
|
||||
if 'chat' in extra_args.model_ability:
|
||||
completion_type = LLMMode.CHAT
|
||||
elif 'generate' in extra_args.model_ability:
|
||||
completion_type = LLMMode.COMPLETION
|
||||
if 'completion_type' in credentials:
|
||||
if credentials['completion_type'] == 'chat':
|
||||
completion_type = LLMMode.CHAT.value
|
||||
elif credentials['completion_type'] == 'completion':
|
||||
completion_type = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
|
||||
else:
|
||||
raise NotImplementedError(f'xinference model ability {extra_args.model_ability} is not supported')
|
||||
extra_args = XinferenceHelper.get_xinference_extra_parameter(
|
||||
server_url=credentials['server_url'],
|
||||
model_uid=credentials['model_uid']
|
||||
)
|
||||
|
||||
if 'chat' in extra_args.model_ability:
|
||||
completion_type = LLMMode.CHAT.value
|
||||
elif 'generate' in extra_args.model_ability:
|
||||
completion_type = LLMMode.COMPLETION.value
|
||||
else:
|
||||
raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
@@ -276,7 +293,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.LLM,
|
||||
model_properties={
|
||||
'mode': completion_type,
|
||||
ModelPropertyKey.MODE: completion_type,
|
||||
},
|
||||
parameter_rules=rules
|
||||
)
|
||||
|
||||
@@ -8,8 +8,9 @@ from typing import (
|
||||
Union
|
||||
)
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, AssistantPromptMessage, \
|
||||
SystemPromptMessage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, \
|
||||
AssistantPromptMessage, \
|
||||
SystemPromptMessage, PromptMessageRole
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
|
||||
LLMResultChunkDelta
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
@@ -111,16 +112,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if len(prompt_messages) == 0:
raise ValueError('At least one message is required')

if prompt_messages[0].role.value == 'system':
if prompt_messages[0].role == PromptMessageRole.SYSTEM:
if not prompt_messages[0].content:
prompt_messages = prompt_messages[1:]

# resolve zhipuai model not support system message and user message, assistant message must be in sequence
new_prompt_messages = []
for prompt_message in prompt_messages:
copy_prompt_message = prompt_message.copy()
if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
if not isinstance(copy_prompt_message.content, str):
# not support image message
continue

if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER:
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
else:
if copy_prompt_message.role == PromptMessageRole.USER:
new_prompt_messages.append(copy_prompt_message)
else:
new_prompt_message = UserPromptMessage(content=copy_prompt_message.content)
new_prompt_messages.append(new_prompt_message)
else:
if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.ASSISTANT:
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
else:
new_prompt_messages.append(copy_prompt_message)

params = {
'model': model,
'prompt': [{
'role': prompt_message.role.value if prompt_message.role.value != 'system' else 'user',
'role': prompt_message.role.value,
'content': prompt_message.content
} for prompt_message in prompt_messages],
} for prompt_message in new_prompt_messages],
**model_parameters
}
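The reordering pass above exists because the ZhipuAI API expects strictly alternating user/assistant turns and accepts neither system nor image messages: non-assistant messages are downgraded to user messages and consecutive same-side turns are merged. The same idea on plain dicts, as a minimal sketch:

    def to_alternating_turns(messages: list[dict]) -> list[dict]:
        merged = []
        for message in messages:
            role = 'assistant' if message['role'] == 'assistant' else 'user'
            if merged and merged[-1]['role'] == role:
                # merge consecutive same-side turns
                merged[-1]['content'] += "\n\n" + message['content']
            else:
                merged.append({'role': role, 'content': message['content']})
        return merged

    # to_alternating_turns([{'role': 'system', 'content': 'Be brief.'},
    #                       {'role': 'user', 'content': 'Hi'}])
    # -> [{'role': 'user', 'content': 'Be brief.\n\nHi'}]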
@@ -1,93 +0,0 @@
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
|
||||
|
||||
class CloudServiceModeration(Moderation):
|
||||
"""
|
||||
The name of custom type must be unique, keep the same with directory and file name.
|
||||
"""
|
||||
name: str = "cloud_service"
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict) -> None:
|
||||
"""
|
||||
schema.json validation. It will be called when user save the config.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
config = {
|
||||
"cloud_provider": "GoogleCloud",
|
||||
"api_endpoint": "https://api.example.com",
|
||||
"api_keys": "123456",
|
||||
"inputs_config": {
|
||||
"enabled": True,
|
||||
"preset_response": "Your content violates our usage policy. Please revise and try again."
|
||||
},
|
||||
"outputs_config": {
|
||||
"enabled": True,
|
||||
"preset_response": "Your content violates our usage policy. Please revise and try again."
|
||||
}
|
||||
}
|
||||
|
||||
:param tenant_id: the id of workspace
|
||||
:param config: the variables of form config
|
||||
:return:
|
||||
"""
|
||||
|
||||
cls._validate_inputs_and_outputs_config(config, True)
|
||||
|
||||
if not config.get("cloud_provider"):
|
||||
raise ValueError("cloud_provider is required")
|
||||
|
||||
if not config.get("api_endpoint"):
|
||||
raise ValueError("api_endpoint is required")
|
||||
|
||||
if not config.get("api_keys"):
|
||||
raise ValueError("api_keys is required")
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
"""
|
||||
Moderation for inputs.
|
||||
|
||||
:param inputs: user inputs
|
||||
:param query: the query of chat app, there is empty if is completion app
|
||||
:return: the moderation result
|
||||
"""
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
if self.config['inputs_config']['enabled']:
|
||||
preset_response = self.config['inputs_config']['preset_response']
|
||||
|
||||
if query:
|
||||
inputs['query__'] = query
|
||||
flagged = self._is_violated(inputs)
|
||||
|
||||
# return ModerationInputsResult(flagged=flagged, action=ModerationAction.OVERRIDED, inputs=inputs, query=query)
|
||||
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
|
||||
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
"""
|
||||
Moderation for outputs.
|
||||
|
||||
:param text: the text of LLM response
|
||||
:return: the moderation result
|
||||
"""
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
if self.config['outputs_config']['enabled']:
|
||||
preset_response = self.config['outputs_config']['preset_response']
|
||||
|
||||
flagged = self._is_violated({'text': text})
|
||||
|
||||
# return ModerationOutputsResult(flagged=flagged, action=ModerationAction.OVERRIDED, text=text)
|
||||
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
|
||||
|
||||
def _is_violated(self, inputs: dict):
|
||||
"""
|
||||
The main logic of moderation.
|
||||
|
||||
:param inputs:
|
||||
:return: the moderation result
|
||||
"""
|
||||
return False
|
||||
@@ -1,65 +0,0 @@
|
||||
{
|
||||
"label": {
|
||||
"en-US": "Cloud Service",
|
||||
"zh-Hans": "云服务"
|
||||
},
|
||||
"form_schema": [
|
||||
{
|
||||
"type": "select",
|
||||
"label": {
|
||||
"en-US": "Cloud Provider",
|
||||
"zh-Hans": "云厂商"
|
||||
},
|
||||
"variable": "cloud_provider",
|
||||
"required": true,
|
||||
"options": [
|
||||
{
|
||||
"label": {
|
||||
"en-US": "AWS",
|
||||
"zh-Hans": "亚马逊"
|
||||
},
|
||||
"value": "AWS"
|
||||
},
|
||||
{
|
||||
"label": {
|
||||
"en-US": "Google Cloud",
|
||||
"zh-Hans": "谷歌云"
|
||||
},
|
||||
"value": "GoogleCloud"
|
||||
},
|
||||
{
|
||||
"label": {
|
||||
"en-US": "Azure Cloud",
|
||||
"zh-Hans": "微软云"
|
||||
},
|
||||
"value": "Azure"
|
||||
}
|
||||
],
|
||||
"default": "GoogleCloud",
|
||||
"placeholder": ""
|
||||
},
|
||||
{
|
||||
"type": "text-input",
|
||||
"label": {
|
||||
"en-US": "API Endpoint",
|
||||
"zh-Hans": "API Endpoint"
|
||||
},
|
||||
"variable": "api_endpoint",
|
||||
"required": true,
|
||||
"max_length": 100,
|
||||
"default": "",
|
||||
"placeholder": "https://api.example.com"
|
||||
},
|
||||
{
|
||||
"type": "paragraph",
|
||||
"label": {
|
||||
"en-US": "API Key",
|
||||
"zh-Hans": "API Key"
|
||||
},
|
||||
"variable": "api_keys",
|
||||
"required": true,
|
||||
"default": "",
|
||||
"placeholder": "Paste your API key here"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -207,7 +207,7 @@ class PromptTransform:
|
||||
|
||||
json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
|
||||
# Open the JSON file and read its content
|
||||
with open(json_file_path, 'r') as json_file:
|
||||
with open(json_file_path, 'r', encoding='utf-8') as json_file:
|
||||
return json.load(json_file)
|
||||
|
||||
def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict,
|
||||
@@ -334,7 +334,18 @@ class PromptTransform:

prompt = re.sub(r'<\|.*?\|>', '', prompt)

return [UserPromptMessage(content=prompt)]
model_mode = ModelMode.value_of(model_config.mode)

if model_mode == ModelMode.CHAT and files:
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
for file in files:
prompt_message_contents.append(file.prompt_message_content)

prompt_message = UserPromptMessage(content=prompt_message_contents)
else:
prompt_message = UserPromptMessage(content=prompt)

return [prompt_message]

def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
if '#context#' in prompt_template.variable_keys:
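With this change, a chat-mode request that carries files builds the user message as a list of content parts (the text first, then one part per file) instead of a plain string. A sketch of that branching; the import path for TextPromptMessageContent is assumed, and `files` is assumed to expose prompt_message_content as in the hunk:

    from core.model_runtime.entities.message_entities import TextPromptMessageContent, UserPromptMessage

    def build_user_message(prompt: str, files: list, is_chat_with_files: bool) -> UserPromptMessage:
        if is_chat_with_files:
            contents = [TextPromptMessageContent(data=prompt)]
            contents.extend(file.prompt_message_content for file in files)
            return UserPromptMessage(content=contents)   # multi-part (vision) message
        return UserPromptMessage(content=prompt)         # plain text message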
@@ -75,7 +75,7 @@ GENERATOR_QA_PROMPT = (
|
||||
'Step 3: Decompose or combine multiple pieces of information and concepts.\n'
|
||||
'Step 4: Generate 20 questions and answers based on these key information and concepts.'
|
||||
'The questions should be clear and detailed, and the answers should be detailed and complete.\n'
|
||||
"Answer according to the the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
||||
"Answer MUST according to the the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n"
|
||||
)
|
||||
|
||||
RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \
|
||||
|
||||
@@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfigurations, Provide
|
||||
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \
|
||||
SystemConfiguration, QuotaConfiguration
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
@@ -23,6 +24,9 @@ class ProviderManager:
|
||||
"""
|
||||
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
self.decoding_rsa_key = None
|
||||
self.decoding_cipher_rsa = None
|
||||
|
||||
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
|
||||
"""
|
||||
@@ -79,9 +83,6 @@ class ProviderManager:
|
||||
# Get All preferred provider types of the workspace
|
||||
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
provider_configurations = ProviderConfigurations(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
@@ -100,19 +101,17 @@ class ProviderManager:
|
||||
|
||||
# Convert to custom configuration
|
||||
custom_configuration = self._to_custom_configuration(
|
||||
tenant_id,
|
||||
provider_entity,
|
||||
provider_records,
|
||||
provider_model_records,
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa
|
||||
provider_model_records
|
||||
)
|
||||
|
||||
# Convert to system configuration
|
||||
system_configuration = self._to_system_configuration(
|
||||
tenant_id,
|
||||
provider_entity,
|
||||
provider_records,
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa
|
||||
provider_records
|
||||
)
|
||||
|
||||
# Get preferred provider type
|
||||
@@ -233,11 +232,18 @@ class ProviderManager:
|
||||
return None
|
||||
|
||||
provider_instance = model_provider_factory.get_provider_instance(default_model.provider_name)
|
||||
provider_schema = provider_instance.get_provider_schema()
|
||||
|
||||
return DefaultModelEntity(
|
||||
model=default_model.model_name,
|
||||
model_type=model_type,
|
||||
provider=DefaultModelProviderEntity(**provider_instance.get_provider_schema().to_simple_provider().dict())
|
||||
provider=DefaultModelProviderEntity(
|
||||
provider=provider_schema.provider,
|
||||
label=provider_schema.label,
|
||||
icon_small=provider_schema.icon_small,
|
||||
icon_large=provider_schema.icon_large,
|
||||
supported_model_types=provider_schema.supported_model_types
|
||||
)
|
||||
)
|
||||
|
||||
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
|
||||
@@ -401,28 +407,29 @@ class ProviderManager:
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value,
|
||||
Provider.is_valid == True
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value
|
||||
).first()
|
||||
|
||||
if provider_record and not provider_record.is_valid:
|
||||
provider_record.is_valid = True
|
||||
db.session.commit()
|
||||
|
||||
provider_name_to_provider_records_dict[provider_name].append(provider_record)
|
||||
|
||||
return provider_name_to_provider_records_dict
|
||||
|
||||
def _to_custom_configuration(self,
|
||||
tenant_id: str,
|
||||
provider_entity: ProviderEntity,
|
||||
provider_records: list[Provider],
|
||||
provider_model_records: list[ProviderModel],
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa) -> CustomConfiguration:
|
||||
provider_model_records: list[ProviderModel]) -> CustomConfiguration:
|
||||
"""
|
||||
Convert to custom configuration.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_entity: provider entity
|
||||
:param provider_records: provider records
|
||||
:param provider_model_records: provider model records
|
||||
:param decoding_rsa_key: decoding rsa key
|
||||
:param decoding_cipher_rsa: decoding cipher rsa
|
||||
:return:
|
||||
"""
|
||||
# Get provider credential secret variables
|
||||
@@ -445,28 +452,49 @@ class ProviderManager:
|
||||
# Get custom provider credentials
|
||||
custom_provider_configuration = None
|
||||
if custom_provider_record:
|
||||
try:
|
||||
# fix origin data
|
||||
if (custom_provider_record.encrypted_config
|
||||
and not custom_provider_record.encrypted_config.startswith("{")):
|
||||
provider_credentials = {
|
||||
"openai_api_key": custom_provider_record.encrypted_config
|
||||
}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
provider_credentials = {}
|
||||
provider_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=custom_provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
|
||||
for variable in provider_credential_secret_variables:
|
||||
if variable in provider_credentials:
|
||||
try:
|
||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_credentials.get(variable),
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
# Get cached provider credentials
|
||||
cached_provider_credentials = provider_credentials_cache.get()
|
||||
|
||||
if not cached_provider_credentials:
|
||||
try:
|
||||
# fix origin data
|
||||
if (custom_provider_record.encrypted_config
|
||||
and not custom_provider_record.encrypted_config.startswith("{")):
|
||||
provider_credentials = {
|
||||
"openai_api_key": custom_provider_record.encrypted_config
|
||||
}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
provider_credentials = {}
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
for variable in provider_credential_secret_variables:
|
||||
if variable in provider_credentials:
|
||||
try:
|
||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_credentials.get(variable),
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# cache provider credentials
|
||||
provider_credentials_cache.set(
|
||||
credentials=provider_credentials
|
||||
)
|
||||
else:
|
||||
provider_credentials = cached_provider_credentials
|
||||
|
||||
custom_provider_configuration = CustomProviderConfiguration(
|
||||
credentials=provider_credentials
|
||||
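The hunks in this file move credential decryption behind ProviderCredentialsCache: the encrypted config is decrypted only on a cache miss, and the RSA decoding material is fetched once per ProviderManager instance. A reduced sketch of that decrypt-once-then-cache flow, with the cache and decrypt callables standing in for the classes used above:

    import json

    def load_provider_credentials(cache, encrypted_config: str, secret_variables: list[str], decrypt) -> dict:
        cached = cache.get()
        if cached:
            return cached
        credentials = json.loads(encrypted_config)
        for variable in secret_variables:
            if variable in credentials:
                # decrypt only the secret fields, as the loop above does
                credentials[variable] = decrypt(credentials[variable])
        cache.set(credentials=credentials)
        return credentials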
@@ -484,21 +512,42 @@ class ProviderManager:
|
||||
if not provider_model_record.encrypted_config:
|
||||
continue
|
||||
|
||||
try:
|
||||
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
continue
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=provider_model_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL
|
||||
)
|
||||
|
||||
for variable in model_credential_secret_variables:
|
||||
if variable in provider_model_credentials:
|
||||
try:
|
||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_model_credentials.get(variable),
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
# Get cached provider model credentials
|
||||
cached_provider_model_credentials = provider_model_credentials_cache.get()
|
||||
|
||||
if not cached_provider_model_credentials:
|
||||
try:
|
||||
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
for variable in model_credential_secret_variables:
|
||||
if variable in provider_model_credentials:
|
||||
try:
|
||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_model_credentials.get(variable),
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# cache provider model credentials
|
||||
provider_model_credentials_cache.set(
|
||||
credentials=provider_model_credentials
|
||||
)
|
||||
else:
|
||||
provider_model_credentials = cached_provider_model_credentials
|
||||
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
@@ -514,17 +563,15 @@ class ProviderManager:
|
||||
)
|
||||
|
||||
def _to_system_configuration(self,
|
||||
tenant_id: str,
|
||||
provider_entity: ProviderEntity,
|
||||
provider_records: list[Provider],
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa) -> SystemConfiguration:
|
||||
provider_records: list[Provider]) -> SystemConfiguration:
|
||||
"""
|
||||
Convert to system configuration.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_entity: provider entity
|
||||
:param provider_records: provider records
|
||||
:param decoding_rsa_key: decoding rsa key
|
||||
:param decoding_cipher_rsa: decoding cipher rsa
|
||||
:return:
|
||||
"""
|
||||
# Get hosting configuration
|
||||
@@ -577,29 +624,50 @@ class ProviderManager:
provider_record = quota_type_to_provider_records_dict.get(current_quota_type)

if provider_record:
try:
provider_credentials = json.loads(provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}

# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
if provider_entity.provider_credential_schema else []
provider_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER
)

for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
except ValueError:
pass
# Get cached provider credentials
cached_provider_credentials = provider_credentials_cache.get()

current_using_credentials = provider_credentials
if not cached_provider_credentials:
try:
provider_credentials = json.loads(provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}

# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
if provider_entity.provider_credential_schema else []
)

# Get decoding rsa key and cipher for decrypting credentials
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)

for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
self.decoding_rsa_key,
self.decoding_cipher_rsa
)
except ValueError:
pass

current_using_credentials = provider_credentials

# cache provider credentials
provider_credentials_cache.set(
credentials=current_using_credentials
)
else:
current_using_credentials = cached_provider_credentials
else:
current_using_credentials = {}

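The hunks above introduce the same read-through pattern for both provider-level and model-level credentials: consult the Redis-backed ProviderCredentialsCache first, and only on a miss parse encrypted_config, lazily build the tenant's decoding key, decrypt the secret variables, and write the result back to the cache. A condensed sketch of that flow, using only the names visible in the diff (the helper function itself and its placement are illustrative, not part of the change):

def _load_cached_credentials(self, tenant_id, record, cache_type, secret_variables):
    # Sketch only: mirrors the cache-miss branch shown above, with error handling trimmed.
    credentials_cache = ProviderCredentialsCache(
        tenant_id=tenant_id,
        identity_id=record.id,
        cache_type=cache_type
    )

    cached_credentials = credentials_cache.get()
    if cached_credentials:
        return cached_credentials

    try:
        credentials = json.loads(record.encrypted_config)
    except JSONDecodeError:
        return {}

    # The decoding key is created once per ProviderManager instance and reused.
    if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
        self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)

    for variable in secret_variables:
        if variable in credentials:
            try:
                credentials[variable] = encrypter.decrypt_token_with_decoding(
                    credentials.get(variable),
                    self.decoding_rsa_key,
                    self.decoding_cipher_rsa
                )
            except ValueError:
                pass

    credentials_cache.set(credentials=credentials)
    return credentials
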
@@ -46,11 +46,11 @@ def init_app(app: Flask) -> Celery:
beat_schedule = {
'clean_embedding_cache_task': {
'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task',
'schedule': timedelta(minutes=1),
'schedule': timedelta(days=7),
},
'clean_unused_datasets_task': {
'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task',
'schedule': timedelta(minutes=10),
'schedule': timedelta(days=7),
}
}
celery_app.conf.update(

@@ -5,7 +5,6 @@ from Crypto.Cipher import PKCS1_OAEP, AES
from Crypto.PublicKey import RSA
from Crypto.Random import get_random_bytes

from core.helper.lru_cache import LRUCache
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage

@@ -46,15 +45,7 @@ def encrypt(text, public_key):
return prefix_hybrid + encrypted_data


tenant_rsa_keys = LRUCache(capacity=1000)


def get_decrypt_decoding(tenant_id):
rsa_key = tenant_rsa_keys.get(tenant_id)
if rsa_key:
cipher_rsa = PKCS1_OAEP.new(rsa_key)
return rsa_key, cipher_rsa

filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"

cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
@@ -70,8 +61,6 @@ def get_decrypt_decoding(tenant_id):
rsa_key = RSA.import_key(private_key)
cipher_rsa = PKCS1_OAEP.new(rsa_key)

tenant_rsa_keys.put(tenant_id, rsa_key)

return rsa_key, cipher_rsa

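With this change the process-local tenant_rsa_keys LRUCache is dropped from get_decrypt_decoding, leaving only the Redis-keyed private-key cache visible in the surrounding context lines. A hedged reconstruction of the remaining lookup (the Redis/storage fall-back order and the TTL are assumptions based on the imports shown, not part of this hunk):

def get_decrypt_decoding(tenant_id):
    filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"

    cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
    private_key = redis_client.get(cache_key)            # assumed: check Redis first
    if not private_key:
        private_key = storage.load(filepath)             # assumed: fall back to object storage
        redis_client.setex(cache_key, 120, private_key)  # assumed TTL, illustrative only

    rsa_key = RSA.import_key(private_key)
    cipher_rsa = PKCS1_OAEP.new(rsa_key)
    return rsa_key, cipher_rsa
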
@@ -6,7 +6,7 @@ from typing import Optional, cast, Tuple
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, DefaultModelEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@@ -14,7 +14,7 @@ from core.provider_manager import ProviderManager
|
||||
from models.provider import ProviderType
|
||||
from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \
|
||||
SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \
|
||||
DefaultModelResponse, ModelWithProviderEntityResponse
|
||||
DefaultModelResponse, ModelWithProviderEntityResponse, SimpleProviderEntityResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,7 +45,17 @@ class ModelProviderService:
|
||||
continue
|
||||
|
||||
provider_response = ProviderResponse(
|
||||
**provider_configuration.provider.dict(),
|
||||
provider=provider_configuration.provider.provider,
|
||||
label=provider_configuration.provider.label,
|
||||
description=provider_configuration.provider.description,
|
||||
icon_small=provider_configuration.provider.icon_small,
|
||||
icon_large=provider_configuration.provider.icon_large,
|
||||
background=provider_configuration.provider.background,
|
||||
help=provider_configuration.provider.help,
|
||||
supported_model_types=provider_configuration.provider.supported_model_types,
|
||||
configurate_methods=provider_configuration.provider.configurate_methods,
|
||||
provider_credential_schema=provider_configuration.provider.provider_credential_schema,
|
||||
model_credential_schema=provider_configuration.provider.model_credential_schema,
|
||||
preferred_provider_type=provider_configuration.preferred_provider_type,
|
||||
custom_configuration=CustomConfigurationResponse(
|
||||
status=CustomConfigurationStatus.ACTIVE
|
||||
@@ -53,7 +63,9 @@ class ModelProviderService:
|
||||
else CustomConfigurationStatus.NO_CONFIGURE
|
||||
),
|
||||
system_configuration=SystemConfigurationResponse(
|
||||
**provider_configuration.system_configuration.dict()
|
||||
enabled=provider_configuration.system_configuration.enabled,
|
||||
current_quota_type=provider_configuration.system_configuration.current_quota_type,
|
||||
quota_configurations=provider_configuration.system_configuration.quota_configurations
|
||||
)
|
||||
)
|
||||
|
||||
@@ -369,7 +381,15 @@ class ModelProviderService:
|
||||
)
|
||||
|
||||
return DefaultModelResponse(
|
||||
**result.dict()
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types
|
||||
)
|
||||
) if result else None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
|
||||
@@ -27,7 +27,7 @@ def disable_segment_from_index_task(segment_id: str):
raise NotFound('Segment not found')

if segment.status != 'completed':
return
raise NotFound('Segment is not completed , disable action is not allowed.')

indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

@@ -29,7 +29,7 @@ def enable_segment_to_index_task(segment_id: str):
raise NotFound('Segment not found')

if segment.status != 'completed':
return
raise NotFound('Segment is not completed, enable action is not allowed.')

indexing_cache_key = 'segment_{}_indexing'.format(segment.id)

@@ -60,7 +60,7 @@
|
||||
<p>Dear {{ to }},</p>
|
||||
<p>{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.</p>
|
||||
<p>You can now log in to Dify using the GitHub or Google account associated with this email.</p>
|
||||
<p style="text-align: center;"><a class="button" href="{{ url }}">Login Here</a></p>
|
||||
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Login Here</a></p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>Best regards,</p>
|
||||
|
||||
@@ -60,7 +60,7 @@
|
||||
<p>尊敬的 {{ to }},</p>
|
||||
<p>{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
|
||||
<p>您现在可以使用与此邮件相对应的 GitHub 或 Google 账号登录 Dify。</p>
|
||||
<p style="text-align: center;"><a class="button" href="{{ url }}">在此登录</a></p>
|
||||
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>此致,</p>
|
||||
|
||||
@@ -39,13 +39,15 @@ def test_invoke_model(setup_openai_mock):
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
"world",
|
||||
" ".join(["long_text"] * 100),
|
||||
" ".join(["another_long_text"] * 100)
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert len(result.embeddings) == 4
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def test_validate_credentials():
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': 'invalid_key',
|
||||
'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
|
||||
'endpoint_url': 'https://api.together.xyz/v1/',
|
||||
'mode': 'chat'
|
||||
}
|
||||
)
|
||||
@@ -31,7 +31,7 @@ def test_validate_credentials():
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('TOGETHER_API_KEY'),
|
||||
'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
|
||||
'endpoint_url': 'https://api.together.xyz/v1/',
|
||||
'mode': 'chat'
|
||||
}
|
||||
)
|
||||
@@ -43,7 +43,7 @@ def test_invoke_model():
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('TOGETHER_API_KEY'),
|
||||
'endpoint_url': 'https://api.together.xyz/v1/completions',
|
||||
'endpoint_url': 'https://api.together.xyz/v1/',
|
||||
'mode': 'completion'
|
||||
},
|
||||
prompt_messages=[
|
||||
@@ -74,7 +74,7 @@ def test_invoke_stream_model():
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('TOGETHER_API_KEY'),
|
||||
'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
|
||||
'endpoint_url': 'https://api.together.xyz/v1/',
|
||||
'mode': 'chat'
|
||||
},
|
||||
prompt_messages=[
|
||||
@@ -110,7 +110,7 @@ def test_invoke_chat_model_with_tools():
|
||||
model='gpt-3.5-turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'endpoint_url': 'https://api.openai.com/v1/chat/completions',
|
||||
'endpoint_url': 'https://api.openai.com/v1/',
|
||||
'mode': 'chat'
|
||||
},
|
||||
prompt_messages=[
|
||||
@@ -165,7 +165,7 @@ def test_get_num_tokens():
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'endpoint_url': 'https://api.openai.com/v1/chat/completions'
|
||||
'endpoint_url': 'https://api.openai.com/v1/'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
|
||||
@@ -18,9 +18,8 @@ def test_validate_credentials():
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'api_key': 'invalid_key',
|
||||
'endpoint_url': 'https://api.openai.com/v1/embeddings',
|
||||
'context_size': 8184,
|
||||
'max_chunks': 32
|
||||
'endpoint_url': 'https://api.openai.com/v1/',
|
||||
'context_size': 8184
|
||||
|
||||
}
|
||||
)
|
||||
@@ -29,9 +28,8 @@ def test_validate_credentials():
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'endpoint_url': 'https://api.openai.com/v1/embeddings',
|
||||
'context_size': 8184,
|
||||
'max_chunks': 32
|
||||
'endpoint_url': 'https://api.openai.com/v1/',
|
||||
'context_size': 8184
|
||||
}
|
||||
)
|
||||
|
||||
@@ -43,20 +41,21 @@ def test_invoke_model():
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'endpoint_url': 'https://api.openai.com/v1/embeddings',
|
||||
'context_size': 8184,
|
||||
'max_chunks': 32
|
||||
'endpoint_url': 'https://api.openai.com/v1/',
|
||||
'context_size': 8184
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
"world",
|
||||
" ".join(["long_text"] * 100),
|
||||
" ".join(["another_long_text"] * 100)
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
assert len(result.embeddings) == 4
|
||||
assert result.usage.total_tokens == 502
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
@@ -67,8 +66,7 @@ def test_get_num_tokens():
|
||||
credentials={
|
||||
'api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'endpoint_url': 'https://api.openai.com/v1/embeddings',
|
||||
'context_size': 8184,
|
||||
'max_chunks': 32
|
||||
'context_size': 8184
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
|
||||
117
api/tests/integration_tests/model_runtime/togetherai/test_llm.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \
|
||||
SystemPromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
|
||||
LLMResultChunk
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.togetherai.llm.llm import TogetherAILargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = TogetherAILargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': 'invalid_key',
|
||||
'mode': 'chat'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('TOGETHER_API_KEY'),
|
||||
'mode': 'chat'
|
||||
}
|
||||
)
|
||||
|
||||
def test_invoke_model():
|
||||
model = TogetherAILargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('TOGETHER_API_KEY'),
|
||||
'mode': 'completion'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
def test_invoke_stream_model():
|
||||
model = TogetherAILargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('TOGETHER_API_KEY'),
|
||||
'mode': 'chat'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = TogetherAILargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('TOGETHER_API_KEY'),
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 21
|
||||
@@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.4.2
image: langgenius/dify-api:0.4.4
restart: always
environment:
# Startup mode, 'api' starts the API server.
@@ -92,6 +92,8 @@ services:
QDRANT_URL: http://qdrant:6333
# The Qdrant API key.
QDRANT_API_KEY: difyai123456
# The Qdrant client timeout setting.
QDRANT_CLIENT_TIMEOUT: 20
# Milvus configuration Only available when VECTOR_STORE is `milvus`.
# The milvus host.
MILVUS_HOST: 127.0.0.1
@@ -128,7 +130,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.4.2
image: langgenius/dify-api:0.4.4
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -170,6 +172,8 @@ services:
QDRANT_URL: http://qdrant:6333
# The Qdrant API key.
QDRANT_API_KEY: difyai123456
# The Qdrant client timeout setting.
QDRANT_CLIENT_TIMEOUT: 20
# Milvus configuration Only available when VECTOR_STORE is `milvus`.
# The milvus host.
MILVUS_HOST: 127.0.0.1
@@ -196,7 +200,7 @@ services:

# Frontend web application.
web:
image: langgenius/dify-web:0.4.2
image: langgenius/dify-web:0.4.4
restart: always
environment:
EDITION: SELF_HOSTED

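QDRANT_CLIENT_TIMEOUT here mirrors the setting added to api/.env.example, now surfaced for both the api and worker containers. A minimal sketch of how such a value is typically handed to a Qdrant client (the qdrant-client constructor does accept a timeout argument; where exactly Dify reads the variable inside its vector-store layer is not shown in this diff and is assumed):

import os

from qdrant_client import QdrantClient

# Hypothetical wiring: read the same environment variables the compose file defines
# and pass the timeout (in seconds) to the client constructor.
client = QdrantClient(
    url=os.getenv("QDRANT_URL", "http://localhost:6333"),
    api_key=os.getenv("QDRANT_API_KEY"),
    timeout=int(os.getenv("QDRANT_CLIENT_TIMEOUT", "20")),
)
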
@@ -23,6 +23,7 @@
]
}
],
"react-hooks/exhaustive-deps": "warn"
"react-hooks/exhaustive-deps": "warn",
"react/display-name": "warn"
}
}
}

@@ -1,6 +1,6 @@
'use client'

import { useTranslation } from "react-i18next"
import { useTranslation } from 'react-i18next'

const DatasetFooter = () => {
const { t } = useTranslation()

@@ -10,4 +10,4 @@ const TextGeneration: FC<IMainProps> = () => {
)
}

export default React.memo(TextGeneration)
export default React.memo(TextGeneration)

@@ -1,13 +1,14 @@
|
||||
'use client'
|
||||
import React, { FC } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import s from './style.module.css'
|
||||
|
||||
export interface ILoaidingAnimProps {
|
||||
export type ILoaidingAnimProps = {
|
||||
type: 'text' | 'avatar'
|
||||
}
|
||||
|
||||
const LoaidingAnim: FC<ILoaidingAnimProps> = ({
|
||||
type
|
||||
type,
|
||||
}) => {
|
||||
return (
|
||||
<div className={`${s['dot-flashing']} ${s[type]}`}></div>
|
||||
|
||||
@@ -23,7 +23,6 @@ const style = {
|
||||
overflow: 'auto',
|
||||
}
|
||||
|
||||
// eslint-disable-next-line react/display-name
|
||||
const Flowchart = React.forwardRef((props: {
|
||||
PrimitiveCode: string
|
||||
}, ref) => {
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
'use client'
|
||||
import React, { FC } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
|
||||
export interface IGroupNameProps {
|
||||
export type IGroupNameProps = {
|
||||
name: string
|
||||
}
|
||||
|
||||
const GroupName: FC<IGroupNameProps> = ({
|
||||
name
|
||||
name,
|
||||
}) => {
|
||||
return (
|
||||
<div className='flex items-center mb-1'>
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
'use client'
|
||||
import React, { FC } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
|
||||
const MoreLikeThisIcon: FC = ({ }) => {
|
||||
const MoreLikeThisIcon: FC = () => {
|
||||
return (
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fillRule="evenodd" clipRule="evenodd" d="M5.83914 0.666748H10.1609C10.6975 0.666741 11.1404 0.666734 11.5012 0.696212C11.8759 0.726829 12.2204 0.792538 12.544 0.957399C13.0457 1.21306 13.4537 1.62101 13.7093 2.12277C13.8742 2.44633 13.9399 2.7908 13.9705 3.16553C14 3.52633 14 3.96923 14 4.50587V7.41171C14 7.62908 14 7.73776 13.9652 7.80784C13.9303 7.87806 13.8939 7.91566 13.8249 7.95288C13.756 7.99003 13.6262 7.99438 13.3665 8.00307C12.8879 8.01909 12.4204 8.14633 11.997 8.36429C10.9478 7.82388 9.62021 7.82912 8.53296 8.73228C7.15064 9.88056 6.92784 11.8645 8.0466 13.2641C8.36602 13.6637 8.91519 14.1949 9.40533 14.6492C9.49781 14.7349 9.54405 14.7777 9.5632 14.8041C9.70784 15.003 9.5994 15.2795 9.35808 15.3271C9.32614 15.3334 9.26453 15.3334 9.14129 15.3334H5.83912C5.30248 15.3334 4.85958 15.3334 4.49878 15.304C4.12405 15.2733 3.77958 15.2076 3.45603 15.0428C2.95426 14.7871 2.54631 14.3792 2.29065 13.8774C2.12579 13.5538 2.06008 13.2094 2.02946 12.8346C1.99999 12.4738 1.99999 12.0309 2 11.4943V4.50587C1.99999 3.96924 1.99999 3.52632 2.02946 3.16553C2.06008 2.7908 2.12579 2.44633 2.29065 2.12277C2.54631 1.62101 2.95426 1.21306 3.45603 0.957399C3.77958 0.792538 4.12405 0.726829 4.49878 0.696212C4.85957 0.666734 5.3025 0.666741 5.83914 0.666748ZM4.66667 5.33342C4.29848 5.33342 4 5.63189 4 6.00008C4 6.36827 4.29848 6.66675 4.66667 6.66675H8.66667C9.03486 6.66675 9.33333 6.36827 9.33333 6.00008C9.33333 5.63189 9.03486 5.33342 8.66667 5.33342H4.66667ZM4 8.66675C4 8.29856 4.29848 8.00008 4.66667 8.00008H6C6.36819 8.00008 6.66667 8.29856 6.66667 8.66675C6.66667 9.03494 6.36819 9.33342 6 9.33342H4.66667C4.29848 9.33342 4 9.03494 4 8.66675ZM4.66667 2.66675C4.29848 2.66675 4 2.96523 4 3.33342C4 3.7016 4.29848 4.00008 4.66667 4.00008H10.6667C11.0349 4.00008 11.3333 3.7016 11.3333 3.33342C11.3333 2.96523 11.0349 2.66675 10.6667 2.66675H4.66667Z" fill="#DD2590" />
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
'use client'
|
||||
import React, { FC } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { PlusIcon } from '@heroicons/react/20/solid'
|
||||
|
||||
export interface IOperationBtnProps {
|
||||
export type IOperationBtnProps = {
|
||||
type: 'add' | 'edit'
|
||||
actionName?: string
|
||||
onClick: () => void
|
||||
@@ -14,13 +15,13 @@ const iconMap = {
|
||||
edit: (<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M6.99998 11.6666H12.25M1.75 11.6666H2.72682C3.01217 11.6666 3.15485 11.6666 3.28912 11.6344C3.40816 11.6058 3.52196 11.5587 3.62635 11.4947C3.74408 11.4226 3.84497 11.3217 4.04675 11.1199L11.375 3.79164C11.8583 3.30839 11.8583 2.52488 11.375 2.04164C10.8918 1.55839 10.1083 1.55839 9.62501 2.04164L2.29674 9.3699C2.09496 9.57168 1.99407 9.67257 1.92192 9.7903C1.85795 9.89469 1.81081 10.0085 1.78224 10.1275C1.75 10.2618 1.75 10.4045 1.75 10.6898V11.6666Z" stroke="#344054" strokeWidth="1.25" strokeLinecap="round" strokeLinejoin="round" />
|
||||
</svg>
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
const OperationBtn: FC<IOperationBtnProps> = ({
|
||||
type,
|
||||
actionName,
|
||||
onClick
|
||||
onClick,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
'use client'
|
||||
import React, { FC } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
|
||||
import s from './style.module.css'
|
||||
|
||||
export interface IVarHighlightProps {
|
||||
export type IVarHighlightProps = {
|
||||
name: string
|
||||
}
|
||||
|
||||
@@ -31,6 +32,4 @@ export const varHighlightHTML = ({ name }: IVarHighlightProps) => {
|
||||
return html
|
||||
}
|
||||
|
||||
|
||||
|
||||
export default React.memo(VarHighlight)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
'use client'
|
||||
import React, { FC } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import WarningMask from '.'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
||||
export interface IHasNotSetAPIProps {
|
||||
export type IHasNotSetAPIProps = {
|
||||
isTrailFinished: boolean
|
||||
onSetting: () => void
|
||||
}
|
||||
@@ -18,7 +19,7 @@ const icon = (
|
||||
|
||||
const HasNotSetAPI: FC<IHasNotSetAPIProps> = ({
|
||||
isTrailFinished,
|
||||
onSetting
|
||||
onSetting,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
'use client'
|
||||
import React, { FC } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
|
||||
import s from './style.module.css'
|
||||
|
||||
export interface IWarningMaskProps {
|
||||
export type IWarningMaskProps = {
|
||||
title: string
|
||||
description: string
|
||||
footer: React.ReactNode
|
||||
|
||||
@@ -1,423 +0,0 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useEffect, useState } from 'react'
|
||||
import cn from 'classnames'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useBoolean, useClickAway, useGetState } from 'ahooks'
|
||||
import { InformationCircleIcon } from '@heroicons/react/24/outline'
|
||||
import produce from 'immer'
|
||||
import ParamItem from './param-item'
|
||||
import { SlidersH } from '@/app/components/base/icons/src/vender/line/mediaAndDevices'
|
||||
import Radio from '@/app/components/base/radio'
|
||||
import Panel from '@/app/components/base/panel'
|
||||
import type { CompletionParams } from '@/models/debug'
|
||||
import { TONE_LIST } from '@/config'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
|
||||
import { formatNumber } from '@/utils/format'
|
||||
import { Brush01 } from '@/app/components/base/icons/src/vender/solid/editor'
|
||||
import { Scales02 } from '@/app/components/base/icons/src/vender/solid/FinanceAndECommerce'
|
||||
import { Target04 } from '@/app/components/base/icons/src/vender/solid/general'
|
||||
import { Sliders02 } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices'
|
||||
import { fetchModelParams } from '@/service/debug'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import type { ModelModeType } from '@/types/app'
|
||||
import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon'
|
||||
import ModelName from '@/app/components/header/account-setting/model-provider-page/model-name'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
|
||||
export type IConfigModelProps = {
|
||||
isAdvancedMode: boolean
|
||||
mode: string
|
||||
modelId: string
|
||||
provider: string
|
||||
setModel: (model: { id: string; provider: string; mode: ModelModeType; features: string[] }) => void
|
||||
completionParams: CompletionParams
|
||||
onCompletionParamsChange: (newParams: CompletionParams) => void
|
||||
disabled: boolean
|
||||
}
|
||||
|
||||
const ConfigModel: FC<IConfigModelProps> = ({
|
||||
isAdvancedMode,
|
||||
modelId,
|
||||
provider,
|
||||
setModel,
|
||||
completionParams,
|
||||
onCompletionParamsChange,
|
||||
disabled,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [isShowConfig, { setFalse: hideConfig, toggle: toogleShowConfig }] = useBoolean(false)
|
||||
const [maxTokenSettingTipVisible, setMaxTokenSettingTipVisible] = useState(false)
|
||||
const configContentRef = React.useRef(null)
|
||||
const {
|
||||
currentProvider,
|
||||
currentModel: currModel,
|
||||
textGenerationModelList,
|
||||
} = useTextGenerationCurrentProviderAndModelAndModelList(
|
||||
{ provider, model: modelId },
|
||||
)
|
||||
|
||||
const media = useBreakpoints()
|
||||
const isMobile = media === MediaType.mobile
|
||||
|
||||
// Cache loaded model param
|
||||
const [allParams, setAllParams, getAllParams] = useGetState<Record<string, Record<string, any>>>({})
|
||||
const currParams = allParams[provider]?.[modelId]
|
||||
const hasEnableParams = currParams && Object.keys(currParams).some(key => currParams[key].enabled)
|
||||
const allSupportParams = ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty', 'max_tokens']
|
||||
const currSupportParams = currParams ? allSupportParams.filter(key => currParams[key].enabled) : allSupportParams
|
||||
if (isAdvancedMode)
|
||||
currSupportParams.push('stop')
|
||||
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
if (!allParams[provider]?.[modelId]) {
|
||||
const res = await fetchModelParams(provider, modelId)
|
||||
const newAllParams = produce(allParams, (draft) => {
|
||||
if (!draft[provider])
|
||||
draft[provider] = {}
|
||||
|
||||
draft[provider][modelId] = res
|
||||
})
|
||||
setAllParams(newAllParams)
|
||||
}
|
||||
})()
|
||||
}, [provider, modelId, allParams, setAllParams])
|
||||
|
||||
useClickAway(() => {
|
||||
hideConfig()
|
||||
}, configContentRef)
|
||||
|
||||
const selectedModel = { name: modelId } // options.find(option => option.id === modelId)
|
||||
|
||||
const ensureModelParamLoaded = (provider: string, modelId: string) => {
|
||||
return new Promise<void>((resolve) => {
|
||||
if (getAllParams()[provider]?.[modelId]) {
|
||||
resolve()
|
||||
return
|
||||
}
|
||||
const runId = setInterval(() => {
|
||||
if (getAllParams()[provider]?.[modelId]) {
|
||||
resolve()
|
||||
clearInterval(runId)
|
||||
}
|
||||
}, 500)
|
||||
})
|
||||
}
|
||||
|
||||
const transformValue = (value: number, fromRange: [number, number], toRange: [number, number]): number => {
|
||||
const [fromStart = 0, fromEnd] = fromRange
|
||||
const [toStart = 0, toEnd] = toRange
|
||||
|
||||
// The following three if is to avoid precision loss
|
||||
if (fromStart === toStart && fromEnd === toEnd)
|
||||
return value
|
||||
|
||||
if (value <= fromStart)
|
||||
return toStart
|
||||
|
||||
if (value >= fromEnd)
|
||||
return toEnd
|
||||
|
||||
const fromLength = fromEnd - fromStart
|
||||
const toLength = toEnd - toStart
|
||||
|
||||
let adjustedValue = (value - fromStart) * (toLength / fromLength) + toStart
|
||||
adjustedValue = parseFloat(adjustedValue.toFixed(2))
|
||||
return adjustedValue
|
||||
}
|
||||
|
||||
const handleSelectModel = ({ id, provider: nextProvider, mode, features }: { id: string; provider: string; mode: ModelModeType; features: string[] }) => {
|
||||
return async () => {
|
||||
const prevParamsRule = getAllParams()[provider]?.[modelId]
|
||||
|
||||
setModel({
|
||||
id,
|
||||
provider: nextProvider || 'openai',
|
||||
mode,
|
||||
features,
|
||||
})
|
||||
|
||||
await ensureModelParamLoaded(nextProvider, id)
|
||||
|
||||
const nextParamsRule = getAllParams()[nextProvider]?.[id]
|
||||
// debugger
|
||||
const nextSelectModelMaxToken = nextParamsRule.max_tokens.max
|
||||
const newConCompletionParams = produce(completionParams, (draft: any) => {
|
||||
if (nextParamsRule.max_tokens.enabled) {
|
||||
if (completionParams.max_tokens > nextSelectModelMaxToken) {
|
||||
Toast.notify({
|
||||
type: 'warning',
|
||||
message: t('common.model.params.setToCurrentModelMaxTokenTip', { maxToken: formatNumber(nextSelectModelMaxToken) }),
|
||||
})
|
||||
draft.max_tokens = parseFloat((nextSelectModelMaxToken * 0.8).toFixed(2))
|
||||
}
|
||||
// prev don't have max token
|
||||
if (!completionParams.max_tokens)
|
||||
draft.max_tokens = nextParamsRule.max_tokens.default
|
||||
}
|
||||
else {
|
||||
delete draft.max_tokens
|
||||
}
|
||||
|
||||
allSupportParams.forEach((key) => {
|
||||
if (key === 'max_tokens')
|
||||
return
|
||||
|
||||
if (!nextParamsRule[key].enabled) {
|
||||
delete draft[key]
|
||||
return
|
||||
}
|
||||
|
||||
if (draft[key] === undefined) {
|
||||
draft[key] = nextParamsRule[key].default || 0
|
||||
return
|
||||
}
|
||||
|
||||
if (!prevParamsRule[key].enabled) {
|
||||
draft[key] = nextParamsRule[key].default || 0
|
||||
return
|
||||
}
|
||||
|
||||
draft[key] = transformValue(
|
||||
draft[key],
|
||||
[prevParamsRule[key].min, prevParamsRule[key].max],
|
||||
[nextParamsRule[key].min, nextParamsRule[key].max],
|
||||
)
|
||||
})
|
||||
})
|
||||
onCompletionParamsChange(newConCompletionParams)
|
||||
}
|
||||
}
|
||||
|
||||
// only openai support this
|
||||
function matchToneId(completionParams: CompletionParams): number {
|
||||
const remvoedCustomeTone = TONE_LIST.slice(0, -1)
|
||||
const CUSTOM_TONE_ID = 4
|
||||
const tone = remvoedCustomeTone.find((tone) => {
|
||||
return tone.config?.temperature === completionParams.temperature
|
||||
&& tone.config?.top_p === completionParams.top_p
|
||||
&& tone.config?.presence_penalty === completionParams.presence_penalty
|
||||
&& tone.config?.frequency_penalty === completionParams.frequency_penalty
|
||||
})
|
||||
return tone ? tone.id : CUSTOM_TONE_ID
|
||||
}
|
||||
|
||||
// tone is a preset of completionParams.
|
||||
const [toneId, setToneId] = React.useState(matchToneId(completionParams)) // default is Balanced
|
||||
const toneTabBgClassName = ({
|
||||
1: 'bg-[#F5F8FF]',
|
||||
2: 'bg-[#F4F3FF]',
|
||||
3: 'bg-[#F6FEFC]',
|
||||
})[toneId] || ''
|
||||
// set completionParams by toneId
|
||||
const handleToneChange = (id: number) => {
|
||||
if (id === 4)
|
||||
return // custom tone
|
||||
const tone = TONE_LIST.find(tone => tone.id === id)
|
||||
if (tone) {
|
||||
setToneId(id)
|
||||
onCompletionParamsChange({
|
||||
...tone.config,
|
||||
max_tokens: completionParams.max_tokens,
|
||||
} as CompletionParams)
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
setToneId(matchToneId(completionParams))
|
||||
}, [completionParams])
|
||||
|
||||
const handleParamChange = (key: string, value: number | string[]) => {
|
||||
if (value === undefined)
|
||||
return
|
||||
if ((completionParams as any)[key] === value)
|
||||
return
|
||||
|
||||
if (key === 'stop') {
|
||||
onCompletionParamsChange({
|
||||
...completionParams,
|
||||
[key]: value as string[],
|
||||
})
|
||||
}
|
||||
else {
|
||||
const currParamsRule = getAllParams()[provider]?.[modelId]
|
||||
let notOutRangeValue = parseFloat((value as number).toFixed(2))
|
||||
notOutRangeValue = Math.max(currParamsRule[key].min, notOutRangeValue)
|
||||
notOutRangeValue = Math.min(currParamsRule[key].max, notOutRangeValue)
|
||||
onCompletionParamsChange({
|
||||
...completionParams,
|
||||
[key]: notOutRangeValue,
|
||||
})
|
||||
}
|
||||
}
|
||||
const ableStyle = 'bg-indigo-25 border-[#2A87F5] cursor-pointer'
|
||||
const diabledStyle = 'bg-[#FFFCF5] border-[#F79009]'
|
||||
|
||||
const getToneIcon = (toneId: number) => {
|
||||
const className = 'w-[14px] h-[14px]'
|
||||
const res = ({
|
||||
1: <Brush01 className={className} />,
|
||||
2: <Scales02 className={className} />,
|
||||
3: <Target04 className={className} />,
|
||||
4: <Sliders02 className={className} />,
|
||||
})[toneId]
|
||||
return res
|
||||
}
|
||||
useEffect(() => {
|
||||
if (!currParams)
|
||||
return
|
||||
|
||||
const max = currParams.max_tokens.max
|
||||
const isSupportMaxToken = currParams.max_tokens.enabled
|
||||
if (isSupportMaxToken && currentProvider?.provider !== 'anthropic' && completionParams.max_tokens > max * 2 / 3)
|
||||
setMaxTokenSettingTipVisible(true)
|
||||
else
|
||||
setMaxTokenSettingTipVisible(false)
|
||||
}, [currParams, completionParams.max_tokens, setMaxTokenSettingTipVisible, currentProvider])
|
||||
return (
|
||||
<div className='relative' ref={configContentRef}>
|
||||
<div
|
||||
className={cn('flex items-center border h-8 px-2 space-x-2 rounded-lg', disabled ? diabledStyle : ableStyle)}
|
||||
onClick={() => !disabled && toogleShowConfig()}
|
||||
>
|
||||
{
|
||||
currentProvider && (
|
||||
<ModelIcon
|
||||
className='!w-5 !h-5'
|
||||
provider={currentProvider}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
currModel && (
|
||||
<ModelName
|
||||
className='text-gray-900'
|
||||
modelItem={currModel}
|
||||
showMode={isAdvancedMode}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{disabled ? <InformationCircleIcon className='w-4 h-4 text-[#F79009]' /> : <SlidersH className='w-4 h-4 text-indigo-600' />}
|
||||
</div>
|
||||
{isShowConfig && (
|
||||
<Panel
|
||||
className='absolute z-20 top-8 left-0 sm:left-[unset] sm:right-0 !w-fit sm:!w-[496px] bg-white !overflow-visible shadow-md'
|
||||
keepUnFold
|
||||
headerIcon={
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M8.26865 0.790031C8.09143 0.753584 7.90866 0.753584 7.73144 0.790031C7.52659 0.832162 7.3435 0.934713 7.19794 1.01624L7.15826 1.03841L6.17628 1.58395C5.85443 1.76276 5.73846 2.16863 5.91727 2.49049C6.09608 2.81234 6.50195 2.9283 6.82381 2.74949L7.80579 2.20395C7.90681 2.14782 7.95839 2.11946 7.99686 2.10091L8.00004 2.09938L8.00323 2.10091C8.0417 2.11946 8.09327 2.14782 8.1943 2.20395L9.17628 2.74949C9.49814 2.9283 9.90401 2.81234 10.0828 2.49048C10.2616 2.16863 10.1457 1.76276 9.82381 1.58395L8.84183 1.03841L8.80215 1.01624C8.65659 0.934713 8.4735 0.832162 8.26865 0.790031Z" fill="#1C64F2" />
|
||||
<path d="M12.8238 3.25062C12.5019 3.07181 12.0961 3.18777 11.9173 3.50963C11.7385 3.83148 11.8544 4.23735 12.1763 4.41616L12.6272 4.66668L12.1763 4.91719C11.8545 5.096 11.7385 5.50186 11.9173 5.82372C12.0961 6.14558 12.502 6.26154 12.8238 6.08273L13.3334 5.79966V6.33339C13.3334 6.70158 13.6319 7.00006 14 7.00006C14.3682 7.00006 14.6667 6.70158 14.6667 6.33339V5.29435L14.6668 5.24627C14.6673 5.12441 14.6678 4.98084 14.6452 4.83482C14.6869 4.67472 14.6696 4.49892 14.5829 4.34286C14.4904 4.1764 14.3371 4.06501 14.1662 4.02099C14.0496 3.93038 13.9239 3.86116 13.8173 3.8024L13.7752 3.77915L12.8238 3.25062Z" fill="#1C64F2" />
|
||||
<path d="M3.8238 4.41616C4.14566 4.23735 4.26162 3.83148 4.08281 3.50963C3.90401 3.18777 3.49814 3.07181 3.17628 3.25062L2.22493 3.77915L2.18284 3.8024C2.07615 3.86116 1.95045 3.9304 1.83382 4.02102C1.66295 4.06506 1.50977 4.17643 1.41731 4.34286C1.33065 4.49886 1.31323 4.67459 1.35493 4.83464C1.33229 4.98072 1.33281 5.12436 1.33326 5.24627L1.33338 5.29435V6.33339C1.33338 6.70158 1.63185 7.00006 2.00004 7.00006C2.36823 7.00006 2.66671 6.70158 2.66671 6.33339V5.79961L3.17632 6.08273C3.49817 6.26154 3.90404 6.14558 4.08285 5.82372C4.26166 5.50186 4.1457 5.096 3.82384 4.91719L3.3729 4.66666L3.8238 4.41616Z" fill="#1C64F2" />
|
||||
<path d="M2.66671 9.66672C2.66671 9.29853 2.36823 9.00006 2.00004 9.00006C1.63185 9.00006 1.33338 9.29853 1.33338 9.66672V10.7058L1.33326 10.7538C1.33262 10.9298 1.33181 11.1509 1.40069 11.3594C1.46024 11.5397 1.55759 11.7051 1.68622 11.8447C1.835 12.0061 2.02873 12.1128 2.18281 12.1977L2.22493 12.221L3.17628 12.7495C3.49814 12.9283 3.90401 12.8123 4.08281 12.4905C4.26162 12.1686 4.14566 11.7628 3.8238 11.584L2.87245 11.0554C2.76582 10.9962 2.71137 10.9656 2.67318 10.9413L2.66995 10.9392L2.66971 10.9354C2.66699 10.8902 2.66671 10.8277 2.66671 10.7058V9.66672Z" fill="#1C64F2" />
|
||||
<path d="M14.6667 9.66672C14.6667 9.29853 14.3682 9.00006 14 9.00006C13.6319 9.00006 13.3334 9.29853 13.3334 9.66672V10.7058C13.3334 10.8277 13.3331 10.8902 13.3304 10.9354L13.3301 10.9392L13.3269 10.9413C13.2887 10.9656 13.2343 10.9962 13.1276 11.0554L12.1763 11.584C11.8544 11.7628 11.7385 12.1686 11.9173 12.4905C12.0961 12.8123 12.5019 12.9283 12.8238 12.7495L13.7752 12.221L13.8172 12.1977C13.9713 12.1128 14.1651 12.0061 14.3139 11.8447C14.4425 11.7051 14.5398 11.5397 14.5994 11.3594C14.6683 11.1509 14.6675 10.9298 14.6668 10.7538L14.6667 10.7058V9.66672Z" fill="#1C64F2" />
|
||||
<path d="M6.82381 13.2506C6.50195 13.0718 6.09608 13.1878 5.91727 13.5096C5.73846 13.8315 5.85443 14.2374 6.17628 14.4162L7.15826 14.9617L7.19793 14.9839C7.29819 15.04 7.41625 15.1061 7.54696 15.1556C7.66589 15.2659 7.82512 15.3333 8.00008 15.3333C8.17507 15.3333 8.33431 15.2659 8.45324 15.1556C8.58391 15.1061 8.70193 15.04 8.80215 14.9839L8.84183 14.9617L9.82381 14.4162C10.1457 14.2374 10.2616 13.8315 10.0828 13.5096C9.90401 13.1878 9.49814 13.0718 9.17628 13.2506L8.66675 13.5337V13C8.66675 12.6318 8.36827 12.3333 8.00008 12.3333C7.63189 12.3333 7.33341 12.6318 7.33341 13V13.5337L6.82381 13.2506Z" fill="#1C64F2" />
|
||||
<path d="M6.82384 6.58385C6.50199 6.40505 6.09612 6.52101 5.91731 6.84286C5.7385 7.16472 5.85446 7.57059 6.17632 7.7494L7.33341 8.39223V9.66663C7.33341 10.0348 7.63189 10.3333 8.00008 10.3333C8.36827 10.3333 8.66675 10.0348 8.66675 9.66663V8.39223L9.82384 7.7494C10.1457 7.57059 10.2617 7.16472 10.0829 6.84286C9.90404 6.52101 9.49817 6.40505 9.17632 6.58385L8.00008 7.23732L6.82384 6.58385Z" fill="#1C64F2" />
|
||||
</svg>
|
||||
}
|
||||
title={t('appDebug.modelConfig.title')}
|
||||
>
|
||||
<div className='py-3 pl-10 pr-6 text-sm'>
|
||||
<div className="flex items-center justify-between my-5 h-9">
|
||||
<div>{t('appDebug.modelConfig.model')}</div>
|
||||
<ModelSelector
|
||||
defaultModel={{ model: modelId, provider }}
|
||||
modelList={textGenerationModelList}
|
||||
onSelect={({ provider, model }) => {
|
||||
const targetProvider = textGenerationModelList.find(modelItem => modelItem.provider === provider)
|
||||
const targetModelItem = targetProvider?.models.find(modelItem => modelItem.model === model)
|
||||
handleSelectModel({
|
||||
id: model,
|
||||
provider,
|
||||
mode: targetModelItem?.model_properties.mode as ModelModeType,
|
||||
features: targetModelItem?.features || [],
|
||||
})()
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{hasEnableParams && (
|
||||
<div className="border-b border-gray-100"></div>
|
||||
)}
|
||||
|
||||
{/* Tone type */}
|
||||
{['openai', 'azure_openai'].includes(provider) && (
|
||||
<div className="mt-5 mb-4">
|
||||
<div className="mb-3 text-sm text-gray-900">{t('appDebug.modelConfig.setTone')}</div>
|
||||
<Radio.Group className={cn('!rounded-lg', toneTabBgClassName)} value={toneId} onChange={handleToneChange}>
|
||||
<>
|
||||
{TONE_LIST.slice(0, 3).map(tone => (
|
||||
<div className='grow flex items-center' key={tone.id}>
|
||||
<Radio
|
||||
value={tone.id}
|
||||
className={cn(tone.id === toneId && 'rounded-md border border-gray-200 shadow-md', '!mr-0 grow !px-1 sm:!px-2 !justify-center text-[13px] font-medium')}
|
||||
labelClassName={cn(tone.id === toneId
|
||||
? ({
|
||||
1: 'text-[#6938EF]',
|
||||
2: 'text-[#444CE7]',
|
||||
3: 'text-[#107569]',
|
||||
})[toneId]
|
||||
: 'text-[#667085]', 'flex items-center space-x-2')}
|
||||
>
|
||||
<>
|
||||
{getToneIcon(tone.id)}
|
||||
{!isMobile && <div>{t(`common.model.tone.${tone.name}`) as string}</div>}
|
||||
<div className=""></div>
|
||||
</>
|
||||
</Radio>
|
||||
{tone.id !== toneId && tone.id + 1 !== toneId && (<div className='h-5 border-r border-gray-200'></div>)}
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
<Radio
|
||||
value={TONE_LIST[3].id}
|
||||
className={cn(toneId === 4 && 'rounded-md border border-gray-200 shadow-md', '!mr-0 grow !px-1 sm:!px-2 !justify-center text-[13px] font-medium')}
|
||||
labelClassName={cn('flex items-center space-x-2 ', toneId === 4 ? 'text-[#155EEF]' : 'text-[#667085]')}
|
||||
>
|
||||
<>
|
||||
{getToneIcon(TONE_LIST[3].id)}
|
||||
{!isMobile && <div>{t(`common.model.tone.${TONE_LIST[3].name}`) as string}</div>}
|
||||
</>
|
||||
</Radio>
|
||||
</Radio.Group>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Params */}
|
||||
<div className={cn(hasEnableParams && 'mt-4', 'space-y-4', !allParams[provider]?.[modelId] && 'flex items-center min-h-[200px]')}>
|
||||
{(allParams[provider]?.[modelId])
|
||||
? (
|
||||
currSupportParams.map(key => (<ParamItem
|
||||
key={key}
|
||||
id={key}
|
||||
name={t(`common.model.params.${key === 'stop' ? 'stop_sequences' : key}`)}
|
||||
tip={t(`common.model.params.${key === 'stop' ? 'stop_sequences' : key}Tip`)}
|
||||
{...currParams[key] as any}
|
||||
value={(completionParams as any)[key] as any}
|
||||
onChange={handleParamChange}
|
||||
inputType={key === 'stop' ? 'inputTag' : 'slider'}
|
||||
/>))
|
||||
)
|
||||
: (
|
||||
<Loading type='area' />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
maxTokenSettingTipVisible && (
|
||||
<div className='flex py-2 pr-4 pl-5 rounded-bl-xl rounded-br-xl bg-[#FFFAEB] border-t border-[#FEF0C7]'>
|
||||
<AlertTriangle className='shrink-0 mr-2 mt-[3px] w-3 h-3 text-[#F79009]' />
|
||||
<div className='mr-2 text-xs font-medium text-gray-700'>{t('common.model.params.maxTokenSettingTip')}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</Panel>
|
||||
)}
|
||||
</div>
|
||||
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(ConfigModel)
|
||||
@@ -1,29 +0,0 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import cn from 'classnames'
|
||||
import type { ModelModeType } from '@/types/app'
|
||||
|
||||
type Props = {
|
||||
className?: string
|
||||
type: ModelModeType
|
||||
isHighlight?: boolean
|
||||
}
|
||||
|
||||
const ModelModeTypeLabel: FC<Props> = ({
|
||||
className,
|
||||
type,
|
||||
isHighlight,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(className, isHighlight ? 'border-indigo-300 text-indigo-600' : 'border-gray-300 text-gray-500', 'flex items-center h-4 px-1 border rounded text-xs font-semibold uppercase text-ellipsis overflow-hidden whitespace-nowrap')}
|
||||
>
|
||||
{t(`appDebug.modelConfig.modeType.${type}`)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(ModelModeTypeLabel)
|
||||
@@ -1,26 +0,0 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
|
||||
export type IModelNameProps = {
|
||||
modelId: string
|
||||
modelDisplayName?: string
|
||||
}
|
||||
|
||||
export const supportI18nModelName = [
|
||||
'gpt-3.5-turbo', 'gpt-3.5-turbo-16k',
|
||||
'gpt-4', 'gpt-4-32k',
|
||||
'text-davinci-003', 'text-embedding-ada-002', 'whisper-1',
|
||||
'claude-instant-1', 'claude-2',
|
||||
]
|
||||
|
||||
const ModelName: FC<IModelNameProps> = ({
|
||||
modelDisplayName,
|
||||
}) => {
|
||||
return (
|
||||
<span className='text-ellipsis overflow-hidden whitespace-nowrap' title={modelDisplayName}>
|
||||
{modelDisplayName}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
export default React.memo(ModelName)
|
||||
@@ -1,95 +0,0 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import Slider from '@/app/components/base/slider'
|
||||
import TagInput from '@/app/components/base/tag-input'
|
||||
|
||||
export const getFitPrecisionValue = (num: number, precision: number | null) => {
|
||||
if (!precision || !(`${num}`).includes('.'))
|
||||
return num
|
||||
|
||||
const currNumPrecision = (`${num}`).split('.')[1].length
|
||||
if (currNumPrecision > precision)
|
||||
return parseFloat(num.toFixed(precision))
|
||||
|
||||
return num
|
||||
}
|
||||
|
||||
export type IParamIteProps = {
|
||||
id: string
|
||||
name: string
|
||||
tip: string
|
||||
value: number | string[]
|
||||
step?: number
|
||||
min?: number
|
||||
max: number
|
||||
precision: number | null
|
||||
onChange: (key: string, value: number | string[]) => void
|
||||
inputType?: 'inputTag' | 'slider'
|
||||
}
|
||||
|
||||
const TIMES_TEMPLATE = '1000000000000'
|
||||
const ParamItem: FC<IParamIteProps> = ({ id, name, tip, step = 0.1, min = 0, max, precision, value, inputType, onChange }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const getToIntTimes = (num: number) => {
|
||||
if (precision)
|
||||
return parseInt(TIMES_TEMPLATE.slice(0, precision + 1), 10)
|
||||
if (num < 5)
|
||||
return 10
|
||||
return 1
|
||||
}
|
||||
|
||||
const times = getToIntTimes(max)
|
||||
|
||||
useEffect(() => {
|
||||
if (precision)
|
||||
onChange(id, getFitPrecisionValue(value, precision))
|
||||
}, [value, precision])
|
||||
return (
|
||||
<div className="flex items-center justify-between flex-wrap gap-y-2">
|
||||
<div className="flex flex-col flex-shrink-0">
|
||||
<div className="flex items-center">
|
||||
<span className="mr-[6px] text-gray-500 text-[13px] font-medium">{name}</span>
|
||||
{/* Give tooltip different tip to avoiding hide bug */}
|
||||
<Tooltip htmlContent={<div className="w-[200px] whitespace-pre-wrap">{tip}</div>} position='top' selector={`param-name-tooltip-${id}`}>
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M8.66667 10.6667H8V8H7.33333M8 5.33333H8.00667M14 8C14 8.78793 13.8448 9.56815 13.5433 10.2961C13.2417 11.0241 12.7998 11.6855 12.2426 12.2426C11.6855 12.7998 11.0241 13.2417 10.2961 13.5433C9.56815 13.8448 8.78793 14 8 14C7.21207 14 6.43185 13.8448 5.7039 13.5433C4.97595 13.2417 4.31451 12.7998 3.75736 12.2426C3.20021 11.6855 2.75825 11.0241 2.45672 10.2961C2.15519 9.56815 2 8.78793 2 8C2 6.4087 2.63214 4.88258 3.75736 3.75736C4.88258 2.63214 6.4087 2 8 2C9.5913 2 11.1174 2.63214 12.2426 3.75736C13.3679 4.88258 14 6.4087 14 8Z" stroke="#9CA3AF" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
|
||||
</svg>
|
||||
</Tooltip>
|
||||
</div>
|
||||
{inputType === 'inputTag' && <div className="text-gray-400 text-xs font-normal">{t('common.model.params.stop_sequencesPlaceholder')}</div>}
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
{inputType === 'inputTag'
|
||||
? <TagInput
|
||||
items={(value ?? []) as string[]}
|
||||
onChange={newSequences => onChange(id, newSequences)}
|
||||
customizedConfirmKey='Tab'
|
||||
/>
|
||||
: (
|
||||
<>
|
||||
<div className="mr-4 w-[120px]">
|
||||
<Slider value={value * times} min={min * times} max={max * times} onChange={(value) => {
|
||||
onChange(id, value / times)
|
||||
}} />
|
||||
</div>
|
||||
<input type="number" min={min} max={max} step={step} className="block w-[64px] h-9 leading-9 rounded-lg border-0 pl-1 pl py-1.5 bg-gray-50 text-gray-900 placeholder:text-gray-400 focus:ring-1 focus:ring-inset focus:ring-primary-600" value={value} onChange={(e) => {
|
||||
let value = getFitPrecisionValue(isNaN(parseFloat(e.target.value)) ? min : parseFloat(e.target.value), precision)
|
||||
if (value < min)
|
||||
value = min
|
||||
|
||||
if (value > max)
|
||||
value = max
|
||||
onChange(id, value)
|
||||
}} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(ParamItem)
|
||||
@@ -1,24 +0,0 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import I18n from '@/context/i18n'
|
||||
import type { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
|
||||
import ProviderConfig from '@/app/components/header/account-setting/model-page/configs'
|
||||
|
||||
export type IProviderNameProps = {
|
||||
provideName: ProviderEnum
|
||||
}
|
||||
|
||||
const ProviderName: FC<IProviderNameProps> = ({
|
||||
provideName,
|
||||
}) => {
|
||||
const { locale } = useContext(I18n)
|
||||
|
||||
return (
|
||||
<span>
|
||||
{ProviderConfig[provideName]?.selector?.name[locale]}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
export default React.memo(ProviderName)
|
||||
@@ -1,16 +1,17 @@
|
||||
'use client'
|
||||
import React, { FC } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
||||
export interface IModalFootProps {
|
||||
export type IModalFootProps = {
|
||||
onConfirm: () => void
|
||||
onCancel: () => void
|
||||
}
|
||||
|
||||
const ModalFoot: FC<IModalFootProps> = ({
|
||||
onConfirm,
|
||||
onCancel
|
||||
onCancel,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.