mirror of
https://github.com/langgenius/dify.git
synced 2026-01-07 23:04:12 +00:00
feat: enhance command execution and status retrieval in virtual environments with transport abstractions
This commit is contained in:
@@ -43,7 +43,6 @@ class CommandStatus(BaseModel):
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
|
||||
pid: int = Field(description="The process ID of the command.")
|
||||
status: Status = Field(description="The status of the command execution.")
|
||||
exit_code: int | None = Field(description="The return code of the command execution.")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.channel.transport import Transport
|
||||
|
||||
|
||||
class VirtualEnvironment(ABC):
|
||||
@@ -116,7 +117,9 @@ class VirtualEnvironment(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def execute_command(self, connection_handle: ConnectionHandle, command: list[str]) -> tuple[int, int, int, int]:
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str]
|
||||
) -> tuple[str, Transport, Transport, Transport]:
|
||||
"""
|
||||
Execute a command in the virtual environment.
|
||||
|
||||
@@ -125,12 +128,13 @@ class VirtualEnvironment(ABC):
|
||||
command (list[str]): The command to execute as a list of strings.
|
||||
|
||||
Returns:
|
||||
tuple[int, int, int, int]: A tuple containing pid and 3 handle to os.pipe(): (stdin, stdout, stderr).
|
||||
tuple[int, Transport, Transport, Transport]: A tuple containing pid and 3 handle
|
||||
to os.pipe(): (stdin, stdout, stderr).
|
||||
After exuection, the 3 handles will be closed by caller.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: int) -> CommandStatus:
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
"""
|
||||
Get the status of a command executed in the virtual environment.
|
||||
|
||||
|
||||
26
api/core/virtual_environment/channel/pipe_transport.py
Normal file
26
api/core/virtual_environment/channel/pipe_transport.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import os
|
||||
|
||||
from core.virtual_environment.channel.transport import Transport
|
||||
|
||||
|
||||
class PipeTransport(Transport):
|
||||
"""
|
||||
A Transport implementation using OS pipes. it requires two file descriptors:
|
||||
one for reading and one for writing.
|
||||
|
||||
NOTE: r_fd and w_fd must be a pair created by os.pipe(). or returned from subprocess.Popen
|
||||
"""
|
||||
|
||||
def __init__(self, r_fd: int, w_fd: int):
|
||||
self.r_fd = r_fd
|
||||
self.w_fd = w_fd
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
os.write(self.w_fd, data)
|
||||
|
||||
def read(self, n: int) -> bytes:
|
||||
return os.read(self.r_fd, n)
|
||||
|
||||
def close(self) -> None:
|
||||
os.close(self.r_fd)
|
||||
os.close(self.w_fd)
|
||||
21
api/core/virtual_environment/channel/socket_transport.py
Normal file
21
api/core/virtual_environment/channel/socket_transport.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import socket
|
||||
|
||||
from core.virtual_environment.channel.transport import Transport
|
||||
|
||||
|
||||
class SocketTransport(Transport):
|
||||
"""
|
||||
A Transport implementation using a socket.
|
||||
"""
|
||||
|
||||
def __init__(self, sock: socket.SocketIO):
|
||||
self.sock = sock
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
self.sock.write(data)
|
||||
|
||||
def read(self, n: int) -> bytes:
|
||||
return self.sock.read(n)
|
||||
|
||||
def close(self) -> None:
|
||||
self.sock.close()
|
||||
25
api/core/virtual_environment/channel/transport.py
Normal file
25
api/core/virtual_environment/channel/transport.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class Transport(Protocol):
|
||||
@abstractmethod
|
||||
def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Write data to the transport.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def read(self, n: int) -> bytes:
|
||||
"""
|
||||
Read up to n bytes from the transport.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the transport.
|
||||
"""
|
||||
pass
|
||||
@@ -1,28 +1,74 @@
|
||||
from collections.abc import Mapping
|
||||
import socket
|
||||
import tarfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from io import BytesIO
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import docker.errors
|
||||
from docker.models.containers import Container
|
||||
|
||||
import docker
|
||||
from core.virtual_environment.__base.entities import Arch, Metadata
|
||||
from core.virtual_environment.__base.exec import ArchNotSupportedError, VirtualEnvironmentLaunchFailedError
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.exec import VirtualEnvironmentLaunchFailedError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.socket_transport import SocketTransport
|
||||
from core.virtual_environment.channel.transport import Transport
|
||||
|
||||
"""
|
||||
EXAMPLE:
|
||||
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment
|
||||
|
||||
options: Mapping[str, Any] = {}
|
||||
|
||||
|
||||
environment = DockerDaemonEnvironment(options=options)
|
||||
connection_handle = environment.establish_connection()
|
||||
|
||||
pid, transport_stdout, transport_stderr, transport_stdin = environment.execute_command(
|
||||
connection_handle, ["uname", "-a"]
|
||||
)
|
||||
|
||||
print(f"Executed command with PID: {pid}")
|
||||
|
||||
# consume stdout
|
||||
output = transport_stdout.read(1024)
|
||||
print(f"Command output: {output.decode().strip()}")
|
||||
|
||||
environment.release_connection(connection_handle)
|
||||
environment.release_environment()
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
_WORKING_DIR = "/workspace"
|
||||
|
||||
def construct_environment(self, options: Mapping[str, Any]) -> Metadata:
|
||||
"""
|
||||
Construct the Docker daemon virtual environment.
|
||||
"""
|
||||
|
||||
docker_sock = options.get("docker_sock", "unix:///var/run/docker.sock")
|
||||
docker_client = self.get_docker_daemon(docker_sock)
|
||||
docker_client = self.get_docker_daemon(options.get("docker_sock", "unix:///var/run/docker.sock"))
|
||||
|
||||
# TODO: use a better image in practice
|
||||
default_docker_image = options.get("docker_agent_image", "ubuntu:latest")
|
||||
container_command = options.get("docker_agent_command", ["sleep", "infinity"])
|
||||
|
||||
container = docker_client.containers.run(image=default_docker_image, detach=True, remove=True)
|
||||
container = docker_client.containers.run(
|
||||
image=default_docker_image,
|
||||
command=container_command,
|
||||
detach=True,
|
||||
remove=True,
|
||||
stdin_open=True,
|
||||
working_dir=self._WORKING_DIR,
|
||||
)
|
||||
|
||||
# wait for the container to be fully started
|
||||
container.reload()
|
||||
@@ -35,8 +81,8 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
arch=self._get_container_architecture(container),
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=5)
|
||||
@classmethod
|
||||
@lru_cache(maxsize=5)
|
||||
def get_docker_daemon(cls, docker_sock: str) -> docker.DockerClient:
|
||||
"""
|
||||
Get the Docker daemon client.
|
||||
@@ -45,16 +91,189 @@ class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
"""
|
||||
return docker.DockerClient(base_url=docker_sock)
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=5)
|
||||
def get_docker_api_client(cls, docker_sock: str) -> docker.APIClient:
|
||||
"""
|
||||
Get the Docker low-level API client.
|
||||
"""
|
||||
return docker.APIClient(base_url=docker_sock)
|
||||
|
||||
def get_docker_sock(self) -> str:
|
||||
return self.options.get("docker_sock", "unix:///var/run/docker.sock")
|
||||
|
||||
@property
|
||||
def _working_dir(self) -> str:
|
||||
return self._WORKING_DIR
|
||||
|
||||
def _get_container(self) -> Container:
|
||||
docker_client = self.get_docker_daemon(self.get_docker_sock())
|
||||
return docker_client.containers.get(self.metadata.id)
|
||||
|
||||
def _normalize_relative_path(self, path: str) -> PurePosixPath:
|
||||
parts: list[str] = []
|
||||
for part in PurePosixPath(path).parts:
|
||||
if part in ("", ".", "/"):
|
||||
continue
|
||||
if part == "..":
|
||||
if not parts:
|
||||
raise ValueError("Path escapes the workspace.")
|
||||
parts.pop()
|
||||
continue
|
||||
parts.append(part)
|
||||
return PurePosixPath(*parts)
|
||||
|
||||
def _relative_path(self, path: str) -> PurePosixPath:
|
||||
normalized = self._normalize_relative_path(path)
|
||||
if normalized.parts:
|
||||
return normalized
|
||||
return PurePosixPath()
|
||||
|
||||
def _container_path(self, path: str) -> str:
|
||||
relative = self._relative_path(path)
|
||||
if not relative.parts:
|
||||
return self._working_dir
|
||||
return f"{self._working_dir}/{relative.as_posix()}"
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
container = self._get_container()
|
||||
relative_path = self._relative_path(path)
|
||||
if not relative_path.parts:
|
||||
raise ValueError("Upload path must point to a file within the workspace.")
|
||||
|
||||
payload = content.getvalue()
|
||||
tar_stream = BytesIO()
|
||||
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
|
||||
tar_info = tarfile.TarInfo(name=relative_path.as_posix())
|
||||
tar_info.size = len(payload)
|
||||
tar.addfile(tar_info, BytesIO(payload))
|
||||
tar_stream.seek(0)
|
||||
container.put_archive(self._working_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] #
|
||||
|
||||
def download_file(self, path: str) -> BytesIO:
|
||||
container = self._get_container()
|
||||
container_path = self._container_path(path)
|
||||
stream, _ = container.get_archive(container_path)
|
||||
tar_stream = BytesIO()
|
||||
for chunk in stream:
|
||||
tar_stream.write(chunk)
|
||||
tar_stream.seek(0)
|
||||
|
||||
with tarfile.open(fileobj=tar_stream, mode="r:*") as tar:
|
||||
members = [member for member in tar.getmembers() if member.isfile()]
|
||||
if not members:
|
||||
return BytesIO()
|
||||
extracted = tar.extractfile(members[0])
|
||||
if extracted is None:
|
||||
return BytesIO()
|
||||
return BytesIO(extracted.read())
|
||||
|
||||
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
|
||||
container = self._get_container()
|
||||
container_path = self._container_path(directory_path)
|
||||
relative_base = self._relative_path(directory_path)
|
||||
try:
|
||||
stream, _ = container.get_archive(container_path)
|
||||
except docker.errors.NotFound:
|
||||
return []
|
||||
tar_stream = BytesIO()
|
||||
for chunk in stream:
|
||||
tar_stream.write(chunk)
|
||||
tar_stream.seek(0)
|
||||
|
||||
files: list[FileState] = []
|
||||
archive_root = PurePosixPath(container_path).name
|
||||
with tarfile.open(fileobj=tar_stream, mode="r:*") as tar:
|
||||
for member in tar.getmembers():
|
||||
if not member.isfile():
|
||||
continue
|
||||
member_path = PurePosixPath(member.name)
|
||||
if member_path.parts and member_path.parts[0] == archive_root:
|
||||
member_path = PurePosixPath(*member_path.parts[1:])
|
||||
if not member_path.parts:
|
||||
continue
|
||||
relative_path = relative_base / member_path
|
||||
files.append(
|
||||
FileState(
|
||||
path=relative_path.as_posix(),
|
||||
size=member.size,
|
||||
created_at=int(member.mtime),
|
||||
updated_at=int(member.mtime),
|
||||
)
|
||||
)
|
||||
if len(files) >= limit:
|
||||
break
|
||||
return files
|
||||
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
return ConnectionHandle(id=uuid4().hex)
|
||||
|
||||
def release_connection(self, connection_handle: ConnectionHandle) -> None:
|
||||
# No action needed for Docker exec connections
|
||||
pass
|
||||
|
||||
def release_environment(self) -> None:
|
||||
try:
|
||||
container = self._get_container()
|
||||
except docker.errors.NotFound:
|
||||
return
|
||||
try:
|
||||
container.remove(force=True)
|
||||
except docker.errors.NotFound:
|
||||
return
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str]
|
||||
) -> tuple[str, Transport, Transport, Transport]:
|
||||
container = self._get_container()
|
||||
container_id = container.id
|
||||
if not isinstance(container_id, str) or not container_id:
|
||||
raise RuntimeError("Docker container ID is not available for exec.")
|
||||
api_client = self.get_docker_api_client(self.get_docker_sock())
|
||||
exec_info: dict[str, object] = cast(
|
||||
dict[str, object],
|
||||
api_client.exec_create( # pyright: ignore[reportUnknownMemberType] #
|
||||
container_id,
|
||||
cmd=command,
|
||||
stdin=True,
|
||||
stdout=True,
|
||||
stderr=True,
|
||||
tty=False,
|
||||
workdir=self._working_dir,
|
||||
),
|
||||
)
|
||||
|
||||
if not isinstance(exec_info.get("Id"), str):
|
||||
raise RuntimeError("Failed to create Docker exec instance.")
|
||||
|
||||
exec_id: str = str(exec_info.get("Id"))
|
||||
raw_sock: socket.SocketIO = cast(socket.SocketIO, api_client.exec_start(exec_id, socket=True, tty=False)) # pyright: ignore[reportUnknownMemberType] #
|
||||
|
||||
transport = SocketTransport(raw_sock)
|
||||
return exec_id, transport, transport, transport
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
api_client = self.get_docker_api_client(self.get_docker_sock())
|
||||
inspect: dict[str, object] = cast(dict[str, object], api_client.exec_inspect(pid)) # pyright: ignore[reportUnknownMemberType] #
|
||||
exit_code = inspect.get("ExitCode")
|
||||
if inspect.get("Running") or exit_code is None:
|
||||
return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
|
||||
if not isinstance(exit_code, int):
|
||||
exit_code = None
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
|
||||
|
||||
def _get_container_architecture(self, container: Container) -> Arch:
|
||||
"""
|
||||
Get the architecture of the Docker container.
|
||||
"""
|
||||
container.reload()
|
||||
arch_str: str = container.attrs["Architecture"]
|
||||
match arch_str.lower():
|
||||
case "x86_64" | "amd64":
|
||||
return Arch.AMD64
|
||||
case "aarch64" | "arm64":
|
||||
return Arch.ARM64
|
||||
case _:
|
||||
raise ArchNotSupportedError(f"Architecture {arch_str} is not supported in DockerDaemonEnvironment.")
|
||||
return Arch.ARM64
|
||||
|
||||
# container.reload()
|
||||
# arch_str = str(container.attrs["Architecture"])
|
||||
# match arch_str.lower():
|
||||
# case "x86_64" | "amd64":
|
||||
# return Arch.AMD64
|
||||
# case "aarch64" | "arm64":
|
||||
# return Arch.ARM64
|
||||
# case _:
|
||||
# raise ArchNotSupportedError(f"Architecture {arch_str} is not supported in DockerDaemonEnvironment.")
|
||||
|
||||
@@ -11,6 +11,8 @@ from uuid import uuid4
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.exec import ArchNotSupportedError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.pipe_transport import PipeTransport
|
||||
from core.virtual_environment.channel.transport import Transport
|
||||
|
||||
|
||||
class LocalVirtualEnvironment(VirtualEnvironment):
|
||||
@@ -114,7 +116,9 @@ class LocalVirtualEnvironment(VirtualEnvironment):
|
||||
# No action needed for local without isolation
|
||||
pass
|
||||
|
||||
def execute_command(self, connection_handle: ConnectionHandle, command: list[str]) -> tuple[int, int, int, int]:
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str]
|
||||
) -> tuple[str, Transport, Transport, Transport]:
|
||||
"""
|
||||
Execute a command in the local virtual environment.
|
||||
|
||||
@@ -156,10 +160,15 @@ class LocalVirtualEnvironment(VirtualEnvironment):
|
||||
os.close(stdout_write_fd)
|
||||
os.close(stderr_write_fd)
|
||||
|
||||
# Return the process ID and file descriptors for stdin, stdout, and stderr
|
||||
return process.pid, stdin_write_fd, stdout_read_fd, stderr_read_fd
|
||||
# Create PipeTransport instances for stdin, stdout, and stderr
|
||||
stdin_transport = PipeTransport(r_fd=stdin_write_fd, w_fd=stdin_write_fd)
|
||||
stdout_transport = PipeTransport(r_fd=stdout_read_fd, w_fd=stdout_read_fd)
|
||||
stderr_transport = PipeTransport(r_fd=stderr_read_fd, w_fd=stderr_read_fd)
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: int) -> CommandStatus:
|
||||
# Return the process ID and file descriptors for stdin, stdout, and stderr
|
||||
return str(process.pid), stdin_transport, stdout_transport, stderr_transport
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
"""
|
||||
Docstring for get_command_status
|
||||
|
||||
@@ -171,14 +180,15 @@ class LocalVirtualEnvironment(VirtualEnvironment):
|
||||
:return: Description
|
||||
:rtype: CommandStatus
|
||||
"""
|
||||
pid_int = int(pid)
|
||||
try:
|
||||
retcode = os.waitpid(pid, os.WNOHANG)[1]
|
||||
retcode = os.waitpid(pid_int, os.WNOHANG)[1]
|
||||
if retcode == 0:
|
||||
return CommandStatus(status=CommandStatus.Status.RUNNING, pid=pid, exit_code=None)
|
||||
return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
|
||||
else:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, pid=pid, exit_code=retcode)
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=retcode)
|
||||
except ChildProcessError:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, pid=pid, exit_code=None)
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
|
||||
|
||||
def _get_os_architecture(self) -> Arch:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user