mirror of https://github.com/langgenius/dify.git
synced 2026-01-26 01:04:18 +00:00

Compare commits

6 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 3d38aa7138 |  |
|  | 7d2552b3f2 |  |
|  | 117a209ad4 |  |
|  | 071e7800a0 |  |
|  | a76fde3d23 |  |
|  | 1fc57d7358 |  |
```diff
@@ -100,7 +100,7 @@ class Config:
         self.CONSOLE_URL = get_env('CONSOLE_URL')
         self.API_URL = get_env('API_URL')
         self.APP_URL = get_env('APP_URL')
-        self.CURRENT_VERSION = "0.3.15"
+        self.CURRENT_VERSION = "0.3.16"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
```
```diff
@@ -1,6 +1,7 @@
 import logging
 from typing import List
 
+import numpy as np
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError
 
@@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings):
                 embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
             except Exception as ex:
                 raise self._embeddings.handle_exceptions(ex)
 
             i = 0
+            normalized_embedding_results = []
             for text in embedding_queue_texts:
                 hash = helper.generate_text_hash(text)
 
                 try:
                     embedding = Embedding(model_name=self._embeddings.name, hash=hash)
-                    embedding.set_embedding(embedding_results[i])
+                    vector = embedding_results[i]
+                    normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
+                    normalized_embedding_results.append(normalized_embedding)
+                    embedding.set_embedding(normalized_embedding)
                     db.session.add(embedding)
                     db.session.commit()
                 except IntegrityError:
@@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings):
                 finally:
                     i += 1
 
-            text_embeddings.extend(embedding_results)
+            text_embeddings.extend(normalized_embedding_results)
         return text_embeddings
 
     def embed_query(self, text: str) -> List[float]:
@@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings):
 
         try:
             embedding_results = self._embeddings.client.embed_query(text)
+            embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
         except Exception as ex:
             raise self._embeddings.handle_exceptions(ex)
 
@@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings):
 
         return embedding_results
 
-
```
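Both hunks above apply the same L2 normalization before an embedding is cached or returned. A minimal standalone sketch of that step (numpy only; the function name is illustrative, and note the diff itself does not guard against zero-norm vectors):

```python
# Minimal sketch of the L2 normalization the diff applies before caching.
# On unit-length vectors, cosine similarity collapses to a plain dot product,
# so downstream similarity search can rank by inner product directly.
import numpy as np

def l2_normalize(vector: list[float]) -> list[float]:
    arr = np.asarray(vector, dtype=np.float64)
    return (arr / np.linalg.norm(arr)).tolist()  # produces NaNs if the norm is 0

a = np.array(l2_normalize([3.0, 4.0]))  # -> [0.6, 0.8], unit length
b = np.array(l2_normalize([4.0, 3.0]))  # -> [0.8, 0.6], unit length

cosine = float(np.dot(a, b))  # equals cosine similarity since |a| == |b| == 1
print(cosine)                 # 0.96
```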
```diff
@@ -1,16 +1,14 @@
 import decimal
 from functools import wraps
 from typing import List, Optional, Any
 
 from langchain import HuggingFaceHub
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import HuggingFaceEndpoint
 from langchain.schema import LLMResult
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
 
 
 class HuggingfaceHubModel(BaseLLM):
@@ -19,12 +17,12 @@ class HuggingfaceHubModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
-            client = HuggingFaceEndpoint(
+            client = HuggingFaceEndpointLLM(
                 endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
-                task='text2text-generation',
+                task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
                 huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
-                callbacks=self.callbacks,
+                callbacks=self.callbacks
             )
         else:
             client = HuggingFaceHub(
```
```diff
@@ -2,7 +2,6 @@ import json
 from typing import Type
 
 from huggingface_hub import HfApi
-from langchain.llms import HuggingFaceEndpoint
 
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHub
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
 from models.provider import ProviderType
 
 
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
         if 'huggingfacehub_endpoint_url' not in credentials:
             raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')
 
+        if 'task_type' not in credentials:
+            raise CredentialsValidateFailedError('Task Type must be provided.')
+
+        if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
+            raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
+
         try:
-            llm = HuggingFaceEndpoint(
+            llm = HuggingFaceEndpointLLM(
                 endpoint_url=credentials['huggingfacehub_endpoint_url'],
-                task="text2text-generation",
+                task=credentials['task_type'],
                 model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
                 huggingfacehub_api_token=credentials['huggingfacehub_api_token']
             )
@@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
         }
 
         credentials = json.loads(provider_model.encrypted_config)
 
+        if 'task_type' not in credentials:
+            credentials['task_type'] = 'text-generation'
+
         if credentials['huggingfacehub_api_token']:
             credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
                 self.provider.tenant_id,
```
api/core/third_party/langchain/llms/huggingface_endpoint_llm.py (vendored, new file, 39 lines)

```diff
@@ -0,0 +1,39 @@
+from typing import Dict
+
+from langchain.llms import HuggingFaceEndpoint
+from pydantic import Extra, root_validator
+
+from langchain.utils import get_from_dict_or_env
+
+
+class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
+    """HuggingFace Endpoint models.
+
+    To use, you should have the ``huggingface_hub`` python package installed, and the
+    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
+    it as a named parameter to the constructor.
+
+    Only supports `text-generation` and `text2text-generation` for now.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import HuggingFaceEndpoint
+            endpoint_url = (
+                "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
+            )
+            hf = HuggingFaceEndpoint(
+                endpoint_url=endpoint_url,
+                huggingfacehub_api_token="my-api-key"
+            )
+    """
+
+    @root_validator(allow_reuse=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        huggingfacehub_api_token = get_from_dict_or_env(
+            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+        )
+
+        values["huggingfacehub_api_token"] = huggingfacehub_api_token
+        return values
```
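For orientation: the vendored subclass only re-implements `validate_environment`, resolving the API token and nothing else, which appears to be what lets the newly configurable `task` values (including `summarization`) pass where the upstream validator was stricter. A usage sketch under those assumptions — the endpoint URL and token are placeholders, and the kwargs mirror `validate_credentials` above:

```python
# Illustrative only: endpoint URL and token are placeholders, not real values.
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM

llm = HuggingFaceEndpointLLM(
    endpoint_url="https://example.us-east-1.aws.endpoints.huggingface.cloud",
    task="summarization",  # any task_type the provider now accepts
    model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
    huggingfacehub_api_token="hf_xxx",
)
print(llm("Summarize: Dify is an LLM app development platform."))
```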
```diff
@@ -49,4 +49,4 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.2.0
+xinference==0.2.1
```
```diff
@@ -19,7 +19,7 @@ from models.dataset import Dataset, DocumentSegment, DatasetQuery
 class HitTestingService:
     @classmethod
     def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict:
-        if dataset.available_document_count == 0 or dataset.available_document_count == 0:
+        if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
             return {
                 "query": {
                     "content": query,
```
```diff
@@ -17,7 +17,8 @@ HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
 INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = {
     'huggingfacehub_api_type': 'inference_endpoints',
     'huggingfacehub_api_token': 'valid_key',
-    'huggingfacehub_endpoint_url': 'valid_url'
+    'huggingfacehub_endpoint_url': 'valid_url',
+    'task_type': 'text-generation'
 }
 
 def encrypt_side_effect(tenant_id, encrypt_key):
```
```diff
@@ -2,7 +2,7 @@ version: '3.1'
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.3.15
+    image: langgenius/dify-api:0.3.16
     restart: always
     environment:
       # Startup mode, 'api' starts the API server.
@@ -124,7 +124,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.3.15
+    image: langgenius/dify-api:0.3.16
     restart: always
     environment:
       # Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -176,7 +176,7 @@ services:
 
   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.3.15
+    image: langgenius/dify-web:0.3.16
     restart: always
     environment:
       EDITION: SELF_HOSTED
```
web/.gitignore (vendored, 2 lines changed)

```diff
@@ -15,6 +15,8 @@
 # production
 /build
 
+/.history
+
 # misc
 .DS_Store
 *.pem
```
```diff
@@ -38,6 +38,7 @@ const config: ProviderConfig = {
   defaultValue: {
     model_type: 'text-generation',
     huggingfacehub_api_type: 'hosted_inference_api',
+    task_type: 'text-generation',
   },
   validateKeys: (v?: FormValue) => {
     if (v?.huggingfacehub_api_type === 'hosted_inference_api') {
@@ -51,10 +52,36 @@ const config: ProviderConfig = {
         'huggingfacehub_api_token',
         'model_name',
         'huggingfacehub_endpoint_url',
+        'task_type',
       ]
     }
     return []
   },
+  filterValue: (v?: FormValue) => {
+    let filteredKeys: string[] = []
+    if (v?.huggingfacehub_api_type === 'hosted_inference_api') {
+      filteredKeys = [
+        'huggingfacehub_api_type',
+        'huggingfacehub_api_token',
+        'model_name',
+        'model_type',
+      ]
+    }
+    if (v?.huggingfacehub_api_type === 'inference_endpoints') {
+      filteredKeys = [
+        'huggingfacehub_api_type',
+        'huggingfacehub_api_token',
+        'model_name',
+        'huggingfacehub_endpoint_url',
+        'task_type',
+        'model_type',
+      ]
+    }
+    return filteredKeys.reduce((prev: FormValue, next: string) => {
+      prev[next] = v?.[next] || ''
+      return prev
+    }, {})
+  },
   fields: [
     {
       type: 'radio',
@@ -120,6 +147,32 @@ const config: ProviderConfig = {
           'zh-Hans': '在此输入您的端点 URL',
         },
       },
+      {
+        hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
+        type: 'radio',
+        key: 'task_type',
+        required: true,
+        label: {
+          'en': 'Task',
+          'zh-Hans': 'Task',
+        },
+        options: [
+          {
+            key: 'text2text-generation',
+            label: {
+              'en': 'Text-to-Text Generation',
+              'zh-Hans': 'Text-to-Text Generation',
+            },
+          },
+          {
+            key: 'text-generation',
+            label: {
+              'en': 'Text Generation',
+              'zh-Hans': 'Text Generation',
+            },
+          },
+        ],
+      },
     ],
   },
 }
```
```diff
@@ -91,6 +91,7 @@ export type ProviderConfigModal = {
   icon: ReactElement
   defaultValue?: FormValue
   validateKeys?: string[] | ((v?: FormValue) => string[])
+  filterValue?: (v?: FormValue) => FormValue
   fields: Field[]
   link: {
     href: string
```
```diff
@@ -124,8 +124,9 @@ const ModelPage = () => {
     updateModelList(ModelType.embeddings)
     mutateProviders()
   }
-  const handleSave = async (v?: FormValue) => {
-    if (v && modelModalConfig) {
+  const handleSave = async (originValue?: FormValue) => {
+    if (originValue && modelModalConfig) {
+      const v = modelModalConfig.filterValue ? modelModalConfig.filterValue(originValue) : originValue
       let body, url
       if (ConfigurableProviders.includes(modelModalConfig.key)) {
         const { model_name, model_type, ...config } = v
```
```diff
@@ -68,7 +68,7 @@ const Form: FC<FormProps> = ({
       return true
     },
     run: () => {
-      return validateModelProviderFn(modelModal!.key, v)
+      return validateModelProviderFn(modelModal!.key, modelModal?.filterValue ? modelModal?.filterValue(v) : v)
     },
   })
 }
```
```diff
@@ -1,6 +1,6 @@
 {
   "name": "dify-web",
-  "version": "0.3.15",
+  "version": "0.3.16",
   "private": true,
   "scripts": {
     "dev": "next dev",
```