Compare commits

..

1 Commits

Author SHA1 Message Date
Joel
c110888aee feat: agent app support generate prompt (#7007) 2024-08-06 17:43:54 +08:00
6 changed files with 18 additions and 21 deletions

View File

@@ -28,7 +28,7 @@ class RetrievalService:
@classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0,
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model',
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
weights: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
@@ -36,6 +36,10 @@ class RetrievalService:
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
keyword_search_documents = []
embedding_search_documents = []
full_text_search_documents = []
hybrid_search_documents = []
threads = []
exceptions = []
# retrieval_model source with keyword

View File

@@ -278,7 +278,6 @@ class DatasetRetrieval:
query=query,
top_k=top_k, score_threshold=score_threshold,
reranking_model=reranking_model,
reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'),
weights=retrieval_model_config.get('weights', None),
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
@@ -432,12 +431,10 @@ class DatasetRetrieval:
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=retrieval_model.get('score_threshold', .0)
score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None)
reranking_model=retrieval_model['reranking_model']
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None),
)

View File

@@ -177,12 +177,10 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get('score_threshold', .0)
score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None)
reranking_model=retrieval_model['reranking_model']
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None),
)

View File

@@ -14,7 +14,6 @@ default_retrieval_model = {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'reranking_mode': 'reranking_model',
'top_k': 2,
'score_threshold_enabled': False
}
@@ -72,15 +71,14 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get('score_threshold', .0)
score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None),
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
reranking_model=retrieval_model['reranking_model']
if retrieval_model['reranking_enable'] else None,
weights=retrieval_model.get('weights', None),
)
else:

View File

@@ -42,11 +42,11 @@ class HitTestingService:
dataset_id=dataset.id,
query=cls.escape_query_for_search(query),
top_k=retrieval_model.get('top_k', 2),
score_threshold=retrieval_model.get('score_threshold', .0)
score_threshold=retrieval_model['score_threshold']
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None),
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
reranking_model=retrieval_model['reranking_model']
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode', None),
weights=retrieval_model.get('weights', None),
)

View File

@@ -166,7 +166,7 @@ const Prompt: FC<ISimplePromptInput> = ({
)}
</div>
<div className='flex items-center'>
{!isAgent && !readonly && !isMobile && (
{!readonly && !isMobile && (
<AutomaticBtn onClick={showAutomaticTrue} />
)}
</div>