feat: enhance command execution and status retrieval in virtual environments with transport abstractions

This commit is contained in:
Yeuoly
2025-12-30 19:37:16 +08:00
parent bac5245cd0
commit 39091fe4df
7 changed files with 334 additions and 30 deletions

View File

@@ -43,7 +43,6 @@ class CommandStatus(BaseModel):
RUNNING = "running"
COMPLETED = "completed"
pid: int = Field(description="The process ID of the command.")
status: Status = Field(description="The status of the command execution.")
exit_code: int | None = Field(description="The return code of the command execution.")

View File

@@ -4,6 +4,7 @@ from io import BytesIO
from typing import Any
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.channel.transport import Transport
class VirtualEnvironment(ABC):
@@ -116,7 +117,9 @@ class VirtualEnvironment(ABC):
"""
@abstractmethod
def execute_command(self, connection_handle: ConnectionHandle, command: list[str]) -> tuple[int, int, int, int]:
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str]
) -> tuple[str, Transport, Transport, Transport]:
"""
Execute a command in the virtual environment.
@@ -125,12 +128,13 @@ class VirtualEnvironment(ABC):
command (list[str]): The command to execute as a list of strings.
Returns:
tuple[int, int, int, int]: A tuple containing pid and 3 handle to os.pipe(): (stdin, stdout, stderr).
tuple[str, Transport, Transport, Transport]: A tuple containing the command identifier (pid)
and 3 transports: (stdin, stdout, stderr).
After execution, the 3 transports will be closed by the caller.
"""
@abstractmethod
def get_command_status(self, connection_handle: ConnectionHandle, pid: int) -> CommandStatus:
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
"""
Get the status of a command executed in the virtual environment.

View File

@@ -0,0 +1,26 @@
import os
from core.virtual_environment.channel.transport import Transport
class PipeTransport(Transport):
    """
    A Transport implementation using OS pipes. It takes two file descriptors:
    one for reading and one for writing.

    NOTE: r_fd and w_fd are expected to come from os.pipe() or from the pipes
    of a subprocess. The same descriptor may be passed for both arguments when
    the transport is only used in one direction; close() then closes it once.
    """

    def __init__(self, r_fd: int, w_fd: int):
        self.r_fd = r_fd
        self.w_fd = w_fd

    def write(self, data: bytes) -> None:
        """Write all of ``data`` to the write descriptor."""
        # os.write reports how many bytes it accepted, which can be fewer than
        # requested, so keep writing the remainder until nothing is left.
        remaining = memoryview(data)
        while remaining:
            written = os.write(self.w_fd, remaining)
            remaining = remaining[written:]

    def read(self, n: int) -> bytes:
        """Read up to ``n`` bytes from the read descriptor."""
        return os.read(self.r_fd, n)

    def close(self) -> None:
        """
        Close both descriptors, closing a shared descriptor only once.

        Raises:
            OSError: the first error hit while closing; the other descriptor
                is still closed before the error is re-raised.
        """
        first_error: OSError | None = None
        # dict.fromkeys de-duplicates while keeping order, so passing the same
        # fd as r_fd and w_fd does not trigger a second close (EBADF).
        for fd in dict.fromkeys((self.r_fd, self.w_fd)):
            try:
                os.close(fd)
            except OSError as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

View File

@@ -0,0 +1,21 @@
import socket
from core.virtual_environment.channel.transport import Transport
class SocketTransport(Transport):
    """
    A Transport implementation using a socket.

    The wrapped object is a raw ``socket.SocketIO`` stream; reads and writes
    are forwarded to it directly.
    """

    def __init__(self, sock: socket.SocketIO):
        self.sock = sock

    def write(self, data: bytes) -> None:
        """
        Write all of ``data`` to the socket.

        Raises:
            BlockingIOError: if the underlying stream reports that it cannot
                accept data right now (its write returned ``None``).
        """
        # A raw stream write may accept only part of the buffer and returns
        # the number of bytes taken, so loop until the whole payload is sent.
        remaining = memoryview(data)
        while remaining:
            written = self.sock.write(remaining)
            if written is None:
                raise BlockingIOError("socket is not ready for writing")
            remaining = remaining[written:]

    def read(self, n: int) -> bytes:
        """Read up to ``n`` bytes from the socket."""
        return self.sock.read(n)

    def close(self) -> None:
        """Close the underlying socket stream."""
        self.sock.close()

View File

@@ -0,0 +1,25 @@
from abc import abstractmethod
from typing import Protocol
class Transport(Protocol):
    """
    Structural interface for a byte channel offering write, read and close.
    """

    @abstractmethod
    def write(self, data: bytes) -> None:
        """Write data to the transport."""

    @abstractmethod
    def read(self, n: int) -> bytes:
        """Read up to n bytes from the transport."""

    @abstractmethod
    def close(self) -> None:
        """Close the transport."""

View File

@@ -1,28 +1,74 @@
from collections.abc import Mapping
import socket
import tarfile
from collections.abc import Mapping, Sequence
from functools import lru_cache
from typing import Any
from io import BytesIO
from pathlib import PurePosixPath
from typing import Any, cast
from uuid import uuid4
import docker.errors
from docker.models.containers import Container
import docker
from core.virtual_environment.__base.entities import Arch, Metadata
from core.virtual_environment.__base.exec import ArchNotSupportedError, VirtualEnvironmentLaunchFailedError
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.__base.exec import VirtualEnvironmentLaunchFailedError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.socket_transport import SocketTransport
from core.virtual_environment.channel.transport import Transport
"""
EXAMPLE:
from collections.abc import Mapping
from typing import Any
from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment
options: Mapping[str, Any] = {}
environment = DockerDaemonEnvironment(options=options)
connection_handle = environment.establish_connection()
pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command(
connection_handle, ["uname", "-a"]
)
print(f"Executed command with PID: {pid}")
# consume stdout
output = transport_stdout.read(1024)
print(f"Command output: {output.decode().strip()}")
environment.release_connection(connection_handle)
environment.release_environment()
"""
class DockerDaemonEnvironment(VirtualEnvironment):
_WORKING_DIR = "/workspace"
def construct_environment(self, options: Mapping[str, Any]) -> Metadata:
"""
Construct the Docker daemon virtual environment.
"""
docker_sock = options.get("docker_sock", "unix:///var/run/docker.sock")
docker_client = self.get_docker_daemon(docker_sock)
docker_client = self.get_docker_daemon(options.get("docker_sock", "unix:///var/run/docker.sock"))
# TODO: use a better image in practice
default_docker_image = options.get("docker_agent_image", "ubuntu:latest")
container_command = options.get("docker_agent_command", ["sleep", "infinity"])
container = docker_client.containers.run(image=default_docker_image, detach=True, remove=True)
container = docker_client.containers.run(
image=default_docker_image,
command=container_command,
detach=True,
remove=True,
stdin_open=True,
working_dir=self._WORKING_DIR,
)
# wait for the container to be fully started
container.reload()
@@ -35,8 +81,8 @@ class DockerDaemonEnvironment(VirtualEnvironment):
arch=self._get_container_architecture(container),
)
@lru_cache(maxsize=5)
@classmethod
@lru_cache(maxsize=5)
def get_docker_daemon(cls, docker_sock: str) -> docker.DockerClient:
"""
Get the Docker daemon client.
@@ -45,16 +91,189 @@ class DockerDaemonEnvironment(VirtualEnvironment):
"""
return docker.DockerClient(base_url=docker_sock)
@classmethod
@lru_cache(maxsize=5)
def get_docker_api_client(cls, docker_sock: str) -> docker.APIClient:
    """
    Get the Docker low-level API client for the given daemon address.

    Results are memoized by the surrounding ``lru_cache`` (up to 5 entries).
    """
    api_client = docker.APIClient(base_url=docker_sock)
    return api_client
def get_docker_sock(self) -> str:
return self.options.get("docker_sock", "unix:///var/run/docker.sock")
@property
def _working_dir(self) -> str:
return self._WORKING_DIR
def _get_container(self) -> Container:
    """Fetch the container whose ID is stored in this environment's metadata."""
    daemon = self.get_docker_daemon(self.get_docker_sock())
    container_id = self.metadata.id
    return daemon.containers.get(container_id)
def _normalize_relative_path(self, path: str) -> PurePosixPath:
parts: list[str] = []
for part in PurePosixPath(path).parts:
if part in ("", ".", "/"):
continue
if part == "..":
if not parts:
raise ValueError("Path escapes the workspace.")
parts.pop()
continue
parts.append(part)
return PurePosixPath(*parts)
def _relative_path(self, path: str) -> PurePosixPath:
normalized = self._normalize_relative_path(path)
if normalized.parts:
return normalized
return PurePosixPath()
def _container_path(self, path: str) -> str:
relative = self._relative_path(path)
if not relative.parts:
return self._working_dir
return f"{self._working_dir}/{relative.as_posix()}"
def upload_file(self, path: str, content: BytesIO) -> None:
container = self._get_container()
relative_path = self._relative_path(path)
if not relative_path.parts:
raise ValueError("Upload path must point to a file within the workspace.")
payload = content.getvalue()
tar_stream = BytesIO()
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
tar_info = tarfile.TarInfo(name=relative_path.as_posix())
tar_info.size = len(payload)
tar.addfile(tar_info, BytesIO(payload))
tar_stream.seek(0)
container.put_archive(self._working_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] #
def download_file(self, path: str) -> BytesIO:
    """
    Download the file at ``path`` from the container workspace.

    The path is fetched as a tar archive via ``get_archive``; the content of
    the first regular-file member is returned. An empty buffer is returned
    when the archive holds no regular file or the member cannot be extracted.
    """
    container = self._get_container()
    archive_chunks, _ = container.get_archive(self._container_path(path))
    archive = BytesIO(b"".join(archive_chunks))
    with tarfile.open(fileobj=archive, mode="r:*") as tar:
        first_file = next((entry for entry in tar.getmembers() if entry.isfile()), None)
        if first_file is None:
            return BytesIO()
        extracted = tar.extractfile(first_file)
        return BytesIO() if extracted is None else BytesIO(extracted.read())
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
container = self._get_container()
container_path = self._container_path(directory_path)
relative_base = self._relative_path(directory_path)
try:
stream, _ = container.get_archive(container_path)
except docker.errors.NotFound:
return []
tar_stream = BytesIO()
for chunk in stream:
tar_stream.write(chunk)
tar_stream.seek(0)
files: list[FileState] = []
archive_root = PurePosixPath(container_path).name
with tarfile.open(fileobj=tar_stream, mode="r:*") as tar:
for member in tar.getmembers():
if not member.isfile():
continue
member_path = PurePosixPath(member.name)
if member_path.parts and member_path.parts[0] == archive_root:
member_path = PurePosixPath(*member_path.parts[1:])
if not member_path.parts:
continue
relative_path = relative_base / member_path
files.append(
FileState(
path=relative_path.as_posix(),
size=member.size,
created_at=int(member.mtime),
updated_at=int(member.mtime),
)
)
if len(files) >= limit:
break
return files
def establish_connection(self) -> ConnectionHandle:
    """Create a connection handle identified by a fresh random UUID hex string."""
    handle_id = uuid4().hex
    return ConnectionHandle(id=handle_id)
def release_connection(self, connection_handle: ConnectionHandle) -> None:
# No action needed for Docker exec connections
pass
def release_environment(self) -> None:
    """
    Force-remove the container backing this environment.

    A container that cannot be found, either at lookup or at removal, is
    treated as already released.
    """
    try:
        self._get_container().remove(force=True)
    except docker.errors.NotFound:
        # Already gone: nothing left to clean up.
        return
def execute_command(
    self, connection_handle: ConnectionHandle, command: list[str]
) -> tuple[str, Transport, Transport, Transport]:
    """
    Start ``command`` inside the container through the Docker exec API.

    Args:
        connection_handle: Handle from ``establish_connection``; not used here.
        command: The command to execute as a list of strings.

    Returns:
        The exec ID followed by one socket-backed transport, returned three
        times for the stdin, stdout and stderr slots.

    Raises:
        RuntimeError: If the container ID is unavailable or the exec instance
            could not be created.
    """
    container_id = self._get_container().id
    if not (isinstance(container_id, str) and container_id):
        raise RuntimeError("Docker container ID is not available for exec.")

    api_client = self.get_docker_api_client(self.get_docker_sock())
    created = api_client.exec_create(  # pyright: ignore[reportUnknownMemberType] #
        container_id,
        cmd=command,
        stdin=True,
        stdout=True,
        stderr=True,
        tty=False,
        workdir=self._working_dir,
    )
    exec_info: dict[str, object] = cast(dict[str, object], created)

    exec_id = exec_info.get("Id")
    if not isinstance(exec_id, str):
        raise RuntimeError("Failed to create Docker exec instance.")

    # NOTE(review): all three streams share this one socket; with tty=False the
    # Docker API presumably multiplexes stdout/stderr frames on it. Confirm
    # whether readers are expected to demultiplex.
    started = api_client.exec_start(exec_id, socket=True, tty=False)  # pyright: ignore[reportUnknownMemberType] #
    raw_sock: socket.SocketIO = cast(socket.SocketIO, started)
    transport = SocketTransport(raw_sock)
    return exec_id, transport, transport, transport
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
    """
    Report the status of a Docker exec instance.

    Args:
        connection_handle: Handle from ``establish_connection``; not used here.
        pid: The exec ID returned by ``execute_command``.

    Returns:
        RUNNING while the inspect data has a truthy ``Running`` or no
        ``ExitCode``; otherwise COMPLETED with the integer exit code
        (``None`` when the reported value is not an int).
    """
    api_client = self.get_docker_api_client(self.get_docker_sock())
    details: dict[str, object] = cast(dict[str, object], api_client.exec_inspect(pid))  # pyright: ignore[reportUnknownMemberType] #
    exit_code = details.get("ExitCode")

    still_running = bool(details.get("Running")) or exit_code is None
    if still_running:
        return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)

    final_code = exit_code if isinstance(exit_code, int) else None
    return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=final_code)
def _get_container_architecture(self, container: Container) -> Arch:
"""
Get the architecture of the Docker container.
"""
container.reload()
arch_str: str = container.attrs["Architecture"]
match arch_str.lower():
case "x86_64" | "amd64":
return Arch.AMD64
case "aarch64" | "arm64":
return Arch.ARM64
case _:
raise ArchNotSupportedError(f"Architecture {arch_str} is not supported in DockerDaemonEnvironment.")
return Arch.ARM64
# container.reload()
# arch_str = str(container.attrs["Architecture"])
# match arch_str.lower():
# case "x86_64" | "amd64":
# return Arch.AMD64
# case "aarch64" | "arm64":
# return Arch.ARM64
# case _:
# raise ArchNotSupportedError(f"Architecture {arch_str} is not supported in DockerDaemonEnvironment.")

View File

@@ -11,6 +11,8 @@ from uuid import uuid4
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
from core.virtual_environment.__base.exec import ArchNotSupportedError
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from core.virtual_environment.channel.pipe_transport import PipeTransport
from core.virtual_environment.channel.transport import Transport
class LocalVirtualEnvironment(VirtualEnvironment):
@@ -114,7 +116,9 @@ class LocalVirtualEnvironment(VirtualEnvironment):
# No action needed for local without isolation
pass
def execute_command(self, connection_handle: ConnectionHandle, command: list[str]) -> tuple[int, int, int, int]:
def execute_command(
self, connection_handle: ConnectionHandle, command: list[str]
) -> tuple[str, Transport, Transport, Transport]:
"""
Execute a command in the local virtual environment.
@@ -156,10 +160,15 @@ class LocalVirtualEnvironment(VirtualEnvironment):
os.close(stdout_write_fd)
os.close(stderr_write_fd)
# Return the process ID and file descriptors for stdin, stdout, and stderr
return process.pid, stdin_write_fd, stdout_read_fd, stderr_read_fd
# Create PipeTransport instances for stdin, stdout, and stderr
stdin_transport = PipeTransport(r_fd=stdin_write_fd, w_fd=stdin_write_fd)
stdout_transport = PipeTransport(r_fd=stdout_read_fd, w_fd=stdout_read_fd)
stderr_transport = PipeTransport(r_fd=stderr_read_fd, w_fd=stderr_read_fd)
def get_command_status(self, connection_handle: ConnectionHandle, pid: int) -> CommandStatus:
# Return the process ID and file descriptors for stdin, stdout, and stderr
return str(process.pid), stdin_transport, stdout_transport, stderr_transport
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
"""
Docstring for get_command_status
@@ -171,14 +180,15 @@ class LocalVirtualEnvironment(VirtualEnvironment):
:return: Description
:rtype: CommandStatus
"""
pid_int = int(pid)
try:
retcode = os.waitpid(pid, os.WNOHANG)[1]
retcode = os.waitpid(pid_int, os.WNOHANG)[1]
if retcode == 0:
return CommandStatus(status=CommandStatus.Status.RUNNING, pid=pid, exit_code=None)
return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
else:
return CommandStatus(status=CommandStatus.Status.COMPLETED, pid=pid, exit_code=retcode)
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=retcode)
except ChildProcessError:
return CommandStatus(status=CommandStatus.Status.COMPLETED, pid=pid, exit_code=None)
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
def _get_os_architecture(self) -> Arch:
"""