mirror of
https://github.com/langgenius/dify.git
synced 2026-01-08 07:14:14 +00:00
Merge branch 'feat/rag-2' of https://github.com/langgenius/dify into feat/rag-2
This commit is contained in:
@@ -124,7 +124,7 @@ class DataSourceNotionListApi(Resource):
|
||||
credential = datasource_provider_service.get_real_credential_by_id(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion",
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
if not credential:
|
||||
@@ -155,8 +155,8 @@ class DataSourceNotionListApi(Resource):
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="langgenius/notion_datasource/notion",
|
||||
datasource_name="notion",
|
||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||
datasource_name="notion_datasource",
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
@@ -209,7 +209,7 @@ class DataSourceNotionApi(Resource):
|
||||
credential = datasource_provider_service.get_real_credential_by_id(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion",
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
@@ -43,7 +44,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
raise ValueError(f"No OAuth Client Config for {provider_id}")
|
||||
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
|
||||
user_id=current_user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
@@ -98,13 +103,24 @@ class DatasourceOAuthCallback(Resource):
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
)
|
||||
credential_id = context.get("credential_id")
|
||||
if credential_id:
|
||||
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
credential_id=context.get("credential_id"),
|
||||
)
|
||||
else:
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
@@ -208,7 +224,8 @@ class DatasourceAuthListApi(Resource):
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
|
||||
class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@@ -82,19 +82,16 @@ class DatasourceProviderService:
|
||||
if key in credential_secret_variables:
|
||||
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
|
||||
return copy_credentials
|
||||
|
||||
def get_default_real_credential(
|
||||
self, tenant_id: str, provider: str, plugin_id: str
|
||||
) -> dict[str, Any]:
|
||||
|
||||
def get_default_real_credential(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
get default credential
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
datasource_provider = (
|
||||
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id,
|
||||
is_default=True,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id).first()
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, is_default=True, provider=provider, plugin_id=plugin_id)
|
||||
.first()
|
||||
)
|
||||
if not datasource_provider:
|
||||
return {}
|
||||
@@ -357,6 +354,35 @@ class DatasourceProviderService:
|
||||
f"{credential_type.get_name()}",
|
||||
)
|
||||
|
||||
def reauthorize_datasource_oauth_provider(
|
||||
self,
|
||||
name: str | None,
|
||||
tenant_id: str,
|
||||
provider_id: DatasourceProviderID,
|
||||
avatar_url: str | None,
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
update datasource oauth provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
target_provider = session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first()
|
||||
if target_provider is None:
|
||||
raise ValueError("provider not found")
|
||||
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.OAUTH2
|
||||
)
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
credentials[key] = encrypter.encrypt_token(tenant_id, value)
|
||||
|
||||
target_provider.encrypted_credentials = credentials
|
||||
target_provider.avatar_url = avatar_url or target_provider.avatar_url
|
||||
target_provider.name = name or target_provider.name
|
||||
session.commit()
|
||||
|
||||
def add_datasource_oauth_provider(
|
||||
self,
|
||||
name: str | None,
|
||||
@@ -625,7 +651,7 @@ class DatasourceProviderService:
|
||||
}
|
||||
)
|
||||
return datasource_credentials
|
||||
|
||||
|
||||
def get_hard_code_datasource_credentials(self, tenant_id: str) -> list[dict]:
|
||||
"""
|
||||
get hard code datasource credentials.
|
||||
@@ -637,14 +663,16 @@ class DatasourceProviderService:
|
||||
datasources = manager.fetch_installed_datasource_providers(tenant_id)
|
||||
datasource_credentials = []
|
||||
for datasource in datasources:
|
||||
if datasource.plugin_id in ["langgenius/firecrawl_datasource", "langgenius/notion_datasource", "langgenius/jina_datasource"]:
|
||||
if datasource.plugin_id in [
|
||||
"langgenius/firecrawl_datasource",
|
||||
"langgenius/notion_datasource",
|
||||
"langgenius/jina_datasource",
|
||||
]:
|
||||
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
|
||||
credentials = self.get_datasource_credentials(
|
||||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
redirect_uri = (
|
||||
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
|
||||
)
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
|
||||
datasource_credentials.append(
|
||||
{
|
||||
"provider": datasource.provider,
|
||||
|
||||
@@ -11,7 +11,13 @@ class OAuthProxyService(BasePluginClient):
|
||||
__KEY_PREFIX__ = "oauth_proxy_context:"
|
||||
|
||||
@staticmethod
|
||||
def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str):
|
||||
def create_proxy_context(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
credential_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Create a proxy context for an OAuth 2.0 authorization request.
|
||||
|
||||
@@ -31,6 +37,8 @@ class OAuthProxyService(BasePluginClient):
|
||||
"tenant_id": tenant_id,
|
||||
"provider": provider,
|
||||
}
|
||||
if credential_id:
|
||||
data["credential_id"] = credential_id
|
||||
redis_client.setex(
|
||||
f"{OAuthProxyService.__KEY_PREFIX__}{context_id}",
|
||||
OAuthProxyService.__MAX_AGE__,
|
||||
|
||||
@@ -48,7 +48,6 @@ import {
|
||||
isSupportCustomRunForm,
|
||||
} from '@/app/components/workflow/utils'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import type { CommonNodeType } from '@/app/components/workflow/types'
|
||||
import { BlockEnum, type Node, NodeRunningStatus } from '@/app/components/workflow/types'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import { useStore } from '@/app/components/workflow/store'
|
||||
@@ -71,15 +70,16 @@ import {
|
||||
} from '@/app/components/plugins/plugin-auth'
|
||||
import { AuthCategory } from '@/app/components/plugins/plugin-auth'
|
||||
import { canFindTool } from '@/utils'
|
||||
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
|
||||
import type { CustomRunFormProps } from '@/app/components/workflow/nodes/data-source/types'
|
||||
import { DataSourceClassification } from '@/app/components/workflow/nodes/data-source/types'
|
||||
import { useModalContext } from '@/context/modal-context'
|
||||
import DataSourceBeforeRunForm from '@/app/components/workflow/nodes/data-source/before-run-form'
|
||||
|
||||
const getCustomRunForm = (nodeType: BlockEnum, payload: CommonNodeType): React.JSX.Element => {
|
||||
const getCustomRunForm = (params: CustomRunFormProps): React.JSX.Element => {
|
||||
const nodeType = params.payload.type
|
||||
switch (nodeType) {
|
||||
case BlockEnum.DataSource:
|
||||
return <DataSourceBeforeRunForm payload={payload as DataSourceNodeType} />
|
||||
return <DataSourceBeforeRunForm {...params} />
|
||||
default:
|
||||
return <div>Custom Run Form: {nodeType} not found</div>
|
||||
}
|
||||
@@ -227,6 +227,7 @@ const BasePanel: FC<BasePanelProps> = ({
|
||||
tabType,
|
||||
isRunAfterSingleRun,
|
||||
setTabType,
|
||||
handleAfterCustomSingleRun,
|
||||
singleRunParams,
|
||||
nodeInfo,
|
||||
setRunInputData,
|
||||
@@ -306,7 +307,11 @@ const BasePanel: FC<BasePanelProps> = ({
|
||||
}
|
||||
|
||||
if (isShowSingleRun) {
|
||||
const form = getCustomRunForm(data.type, data)
|
||||
const form = getCustomRunForm({
|
||||
payload: data,
|
||||
onSuccess: handleAfterCustomSingleRun,
|
||||
onCancel: hideSingleRun,
|
||||
})
|
||||
|
||||
return (
|
||||
<div className={cn(
|
||||
|
||||
@@ -32,6 +32,7 @@ import {
|
||||
import useInspectVarsCrud from '@/app/components/workflow/hooks/use-inspect-vars-crud'
|
||||
import { useInvalidLastRun } from '@/service/use-workflow'
|
||||
import { useStore, useWorkflowStore } from '@/app/components/workflow/store'
|
||||
import { isSupportCustomRunForm } from '@/app/components/workflow/utils'
|
||||
|
||||
const singleRunFormParamsHooks: Record<BlockEnum, any> = {
|
||||
[BlockEnum.LLM]: useLLMSingleRunFormParams,
|
||||
@@ -117,6 +118,7 @@ const useLastRun = <T>({
|
||||
const isIterationNode = blockType === BlockEnum.Iteration
|
||||
const isLoopNode = blockType === BlockEnum.Loop
|
||||
const isAggregatorNode = blockType === BlockEnum.VariableAggregator
|
||||
const isCustomRunNode = isSupportCustomRunForm(blockType)
|
||||
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
|
||||
const {
|
||||
getData: getDataForCheckMore,
|
||||
@@ -299,10 +301,20 @@ const useLastRun = <T>({
|
||||
})
|
||||
}
|
||||
|
||||
const handleAfterCustomSingleRun = () => {
|
||||
invalidLastRun()
|
||||
setTabType(TabType.lastRun)
|
||||
hideSingleRun()
|
||||
}
|
||||
|
||||
const handleSingleRun = () => {
|
||||
const { isValid } = checkValid()
|
||||
if(!isValid)
|
||||
return
|
||||
if(isCustomRunNode) {
|
||||
showSingleRun()
|
||||
return
|
||||
}
|
||||
const vars = singleRunParams?.getDependentVars?.()
|
||||
// no need to input params
|
||||
if (isAggregatorNode ? checkAggregatorVarsSet(vars) : isAllVarsHasValue(vars)) {
|
||||
@@ -323,6 +335,7 @@ const useLastRun = <T>({
|
||||
tabType,
|
||||
isRunAfterSingleRun,
|
||||
setTabType: handleTabClicked,
|
||||
handleAfterCustomSingleRun,
|
||||
singleRunParams,
|
||||
nodeInfo,
|
||||
setRunInputData,
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import type { DataSourceNodeType } from './types'
|
||||
import type { CustomRunFormProps, DataSourceNodeType } from './types'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
||||
type Props = {
|
||||
payload: DataSourceNodeType
|
||||
}
|
||||
|
||||
const BeforeRunForm: FC<Props> = ({
|
||||
const BeforeRunForm: FC<CustomRunFormProps> = ({
|
||||
payload,
|
||||
onSuccess,
|
||||
onCancel,
|
||||
}) => {
|
||||
return (
|
||||
<div>
|
||||
DataSource: {payload.datasource_name}
|
||||
DataSource: {(payload as DataSourceNodeType).datasource_name}
|
||||
<div className='mt-3 flex justify-center space-x-2'>
|
||||
<Button onClick={onSuccess} variant='primary'>Have runned</Button>
|
||||
<Button onClick={onCancel}>Cancel</Button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -28,3 +28,9 @@ export type DataSourceNodeType = CommonNodeType & {
|
||||
datasource_parameters: ToolVarInputs
|
||||
datasource_configurations: Record<string, any>
|
||||
}
|
||||
|
||||
export type CustomRunFormProps = {
|
||||
payload: CommonNodeType
|
||||
onSuccess: () => void
|
||||
onCancel: () => void
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user