mirror of
https://github.com/langgenius/dify.git
synced 2026-01-01 20:17:16 +00:00
Compare commits
14 Commits
refactor/q
...
feat/suppo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf7e2d5d75 | ||
|
|
2673fe05a5 | ||
|
|
180fdffab1 | ||
|
|
62e422f75a | ||
|
|
41565e91ed | ||
|
|
c9610e9949 | ||
|
|
29dc083d8d | ||
|
|
f679065d2c | ||
|
|
0a97e87a8e | ||
|
|
4d81455a83 | ||
|
|
39091fe4df | ||
|
|
bac5245cd0 | ||
|
|
274f9a3f32 | ||
|
|
a513ab9a59 |
60
api/controllers/console/workspace/dsl.py
Normal file
60
api/controllers/console/workspace/dsl.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.app_dsl_service import AppDslService
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/dsl/predict")
|
||||
class DSLPredictApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_id", type=str, required=True, location="json")
|
||||
.add_argument("current_node_id", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app_id: str = args["app_id"]
|
||||
current_node_id: str = args["current_node_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
app = session.query(App).filter_by(id=app_id).first()
|
||||
workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()
|
||||
|
||||
try:
|
||||
i = 0
|
||||
for node_id, _ in workflow.walk_nodes():
|
||||
if node_id == current_node_id:
|
||||
break
|
||||
i += 1
|
||||
|
||||
dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))
|
||||
|
||||
response = httpx.post(
|
||||
"http://spark-832c:8000/predict",
|
||||
json={"graph_data": dsl, "source_node_index": i},
|
||||
)
|
||||
return {
|
||||
"nodes": json.loads(response.json()),
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
58
api/core/virtual_environment/__base/entities.py
Normal file
58
api/core/virtual_environment/__base/entities.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Arch(StrEnum):
|
||||
"""
|
||||
Architecture types for virtual environments.
|
||||
"""
|
||||
|
||||
ARM64 = "arm64"
|
||||
AMD64 = "amd64"
|
||||
|
||||
|
||||
class Metadata(BaseModel):
|
||||
"""
|
||||
Returned metadata about a virtual environment.
|
||||
"""
|
||||
|
||||
id: str = Field(description="The unique identifier of the virtual environment.")
|
||||
arch: Arch = Field(description="Which architecture was used to create the virtual environment.")
|
||||
store: Mapping[str, Any] = Field(
|
||||
default_factory=dict, description="The store information of the virtual environment., Additional data."
|
||||
)
|
||||
|
||||
|
||||
class ConnectionHandle(BaseModel):
|
||||
"""
|
||||
Handle for managing connections to the virtual environment.
|
||||
"""
|
||||
|
||||
id: str = Field(description="The unique identifier of the connection handle.")
|
||||
|
||||
|
||||
class CommandStatus(BaseModel):
|
||||
"""
|
||||
Status of a command executed in the virtual environment.
|
||||
"""
|
||||
|
||||
class Status(StrEnum):
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
|
||||
status: Status = Field(description="The status of the command execution.")
|
||||
exit_code: int | None = Field(description="The return code of the command execution.")
|
||||
|
||||
|
||||
class FileState(BaseModel):
|
||||
"""
|
||||
State of a file in the virtual environment.
|
||||
"""
|
||||
|
||||
size: int = Field(description="The size of the file in bytes.")
|
||||
path: str = Field(description="The path of the file in the virtual environment.")
|
||||
created_at: int = Field(description="The creation timestamp of the file.")
|
||||
updated_at: int = Field(description="The last modified timestamp of the file.")
|
||||
16
api/core/virtual_environment/__base/exec.py
Normal file
16
api/core/virtual_environment/__base/exec.py
Normal file
@@ -0,0 +1,16 @@
|
||||
class ArchNotSupportedError(Exception):
|
||||
"""Exception raised when the architecture is not supported."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class VirtualEnvironmentLaunchFailedError(Exception):
|
||||
"""Exception raised when launching the virtual environment fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NotSupportedOperationError(Exception):
|
||||
"""Exception raised when an operation is not supported."""
|
||||
|
||||
pass
|
||||
146
api/core/virtual_environment/__base/virtual_environment.py
Normal file
146
api/core/virtual_environment/__base/virtual_environment.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.__base.entities import CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
|
||||
|
||||
|
||||
class VirtualEnvironment(ABC):
|
||||
"""
|
||||
Base class for virtual environment implementations.
|
||||
"""
|
||||
|
||||
def __init__(self, options: Mapping[str, Any], environments: Mapping[str, str] | None = None) -> None:
|
||||
"""
|
||||
Initialize the virtual environment with metadata.
|
||||
"""
|
||||
|
||||
self.options = options
|
||||
self.metadata = self.construct_environment(options, environments or {})
|
||||
|
||||
@abstractmethod
|
||||
def construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
"""
|
||||
Construct the unique identifier for the virtual environment.
|
||||
|
||||
Returns:
|
||||
str: The unique identifier of the virtual environment.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
"""
|
||||
Upload a file to the virtual environment.
|
||||
|
||||
Args:
|
||||
path (str): The destination path in the virtual environment.
|
||||
content (BytesIO): The content of the file to upload.
|
||||
|
||||
Raises:
|
||||
Exception: If the file cannot be uploaded.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def download_file(self, path: str) -> BytesIO:
|
||||
"""
|
||||
Download a file from the virtual environment.
|
||||
|
||||
Args:
|
||||
source_path (str): The source path in the virtual environment.
|
||||
Returns:
|
||||
BytesIO: The content of the downloaded file.
|
||||
Raises:
|
||||
Exception: If the file cannot be downloaded.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
|
||||
"""
|
||||
List files in a directory of the virtual environment.
|
||||
|
||||
Args:
|
||||
directory_path (str): The directory path in the virtual environment.
|
||||
limit (int): The maximum number of files(including recursive paths) to return.
|
||||
Returns:
|
||||
Sequence[FileState]: A list of file states in the specified directory.
|
||||
Raises:
|
||||
Exception: If the files cannot be listed.
|
||||
|
||||
Example:
|
||||
If the directory structure is like:
|
||||
/dir
|
||||
/subdir1
|
||||
file1.txt
|
||||
/subdir2
|
||||
file2.txt
|
||||
And limit is 2, the returned list may look like:
|
||||
[
|
||||
FileState(path="/dir/subdir1/file1.txt", is_directory=False, size=1234, created_at=..., updated_at=...),
|
||||
FileState(path="/dir/subdir2", is_directory=True, size=0, created_at=..., updated_at=...),
|
||||
]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
"""
|
||||
Establish a connection to the virtual environment.
|
||||
|
||||
Returns:
|
||||
ConnectionHandle: Handle for managing the connection to the virtual environment.
|
||||
|
||||
Raises:
|
||||
Exception: If the connection cannot be established.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def release_connection(self, connection_handle: ConnectionHandle) -> None:
|
||||
"""
|
||||
Release the connection to the virtual environment.
|
||||
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The handle for managing the connection.
|
||||
|
||||
Raises:
|
||||
Exception: If the connection cannot be released.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def release_environment(self) -> None:
|
||||
"""
|
||||
Release the virtual environment.
|
||||
|
||||
Raises:
|
||||
Exception: If the environment cannot be released.
|
||||
Multiple calls to `release_environment` with the same `environment_id` is acceptable.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
"""
|
||||
Execute a command in the virtual environment.
|
||||
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The handle for managing the connection.
|
||||
command (list[str]): The command to execute as a list of strings.
|
||||
|
||||
Returns:
|
||||
tuple[int, TransportWriteCloser, TransportReadCloser, TransportReadCloser]
|
||||
a tuple containing pid and 3 handle to os.pipe(): (stdin, stdout, stderr).
|
||||
After exuection, the 3 handles will be closed by caller.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
"""
|
||||
Get the status of a command executed in the virtual environment.
|
||||
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The handle for managing the connection.
|
||||
pid (int): The process ID of the command.
|
||||
Returns:
|
||||
CommandStatus: The status of the command execution.
|
||||
"""
|
||||
4
api/core/virtual_environment/channel/exec.py
Normal file
4
api/core/virtual_environment/channel/exec.py
Normal file
@@ -0,0 +1,4 @@
|
||||
class TransportEOFError(Exception):
|
||||
"""Exception raised when attempting to read from a closed transport."""
|
||||
|
||||
pass
|
||||
72
api/core/virtual_environment/channel/pipe_transport.py
Normal file
72
api/core/virtual_environment/channel/pipe_transport.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import os
|
||||
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
|
||||
|
||||
|
||||
class PipeTransport(Transport):
|
||||
"""
|
||||
A Transport implementation using OS pipes. it requires two file descriptors:
|
||||
one for reading and one for writing.
|
||||
|
||||
NOTE: r_fd and w_fd must be a pair created by os.pipe(). or returned from subprocess.Popen
|
||||
|
||||
NEVER FORGET TO CALL `close()` METHOD TO AVOID FILE DESCRIPTOR LEAKAGE.
|
||||
"""
|
||||
|
||||
def __init__(self, r_fd: int, w_fd: int):
|
||||
self.r_fd = r_fd
|
||||
self.w_fd = w_fd
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
try:
|
||||
os.write(self.w_fd, data)
|
||||
except OSError:
|
||||
raise TransportEOFError("Pipe write error, maybe the read end is closed")
|
||||
|
||||
def read(self, n: int) -> bytes:
|
||||
data = os.read(self.r_fd, n)
|
||||
if data == b"":
|
||||
raise TransportEOFError("End of Pipe reached")
|
||||
return data
|
||||
|
||||
def close(self) -> None:
|
||||
os.close(self.r_fd)
|
||||
os.close(self.w_fd)
|
||||
|
||||
|
||||
class PipeReadCloser(TransportReadCloser):
|
||||
"""
|
||||
A Transport implementation using OS pipe for reading.
|
||||
"""
|
||||
|
||||
def __init__(self, r_fd: int):
|
||||
self.r_fd = r_fd
|
||||
|
||||
def read(self, n: int) -> bytes:
|
||||
data = os.read(self.r_fd, n)
|
||||
if data == b"":
|
||||
raise TransportEOFError("End of Pipe reached")
|
||||
|
||||
return data
|
||||
|
||||
def close(self) -> None:
|
||||
os.close(self.r_fd)
|
||||
|
||||
|
||||
class PipeWriteCloser(TransportWriteCloser):
|
||||
"""
|
||||
A Transport implementation using OS pipe for writing.
|
||||
"""
|
||||
|
||||
def __init__(self, w_fd: int):
|
||||
self.w_fd = w_fd
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
try:
|
||||
os.write(self.w_fd, data)
|
||||
except OSError:
|
||||
raise TransportEOFError("Pipe write error, maybe the read end is closed")
|
||||
|
||||
def close(self) -> None:
|
||||
os.close(self.w_fd)
|
||||
100
api/core/virtual_environment/channel/queue_transport.py
Normal file
100
api/core/virtual_environment/channel/queue_transport.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from queue import Queue
|
||||
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser
|
||||
|
||||
|
||||
class QueueTransportReadCloser(TransportReadCloser):
|
||||
"""
|
||||
Transport implementation using queues for inter-thread communication.
|
||||
|
||||
Usage:
|
||||
q_transport = QueueTransportReadCloser()
|
||||
write_handler = q_transport.get_write_handler()
|
||||
|
||||
# In writer thread
|
||||
write_handler.write(b"data")
|
||||
|
||||
# In reader thread
|
||||
data = q_transport.read(1024)
|
||||
|
||||
# Close transport when done
|
||||
q_transport.close()
|
||||
"""
|
||||
|
||||
class WriteHandler:
|
||||
"""
|
||||
A write handler that writes data to a queue.
|
||||
"""
|
||||
|
||||
def __init__(self, queue: Queue[bytes | None]) -> None:
|
||||
self.queue = queue
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
self.queue.put(data)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the QueueTransportReadCloser with write function.
|
||||
"""
|
||||
self.q = Queue[bytes | None]()
|
||||
self._read_buffer = bytearray()
|
||||
self._closed = False
|
||||
self._write_channel_closed = False
|
||||
|
||||
def get_write_handler(self) -> WriteHandler:
|
||||
"""
|
||||
Get a write handler that writes to the internal queue.
|
||||
"""
|
||||
return QueueTransportReadCloser.WriteHandler(self.q)
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the transport by putting a sentinel value in the queue.
|
||||
"""
|
||||
if self._write_channel_closed:
|
||||
raise TransportEOFError("Write channel already closed")
|
||||
|
||||
self._write_channel_closed = True
|
||||
self.q.put(None)
|
||||
|
||||
def read(self, n: int) -> bytes:
|
||||
"""
|
||||
Read up to n bytes from the queue.
|
||||
|
||||
NEVER USE IT IN A MULTI-THREADED CONTEXT WITHOUT PROPER SYNCHRONIZATION.
|
||||
"""
|
||||
if n <= 0:
|
||||
return b""
|
||||
|
||||
if self._closed:
|
||||
raise TransportEOFError("Transport is closed")
|
||||
|
||||
to_return = self._drain_buffer(n)
|
||||
while len(to_return) < n and not self._closed:
|
||||
chunk = self.q.get()
|
||||
if chunk is None:
|
||||
self._closed = True
|
||||
raise TransportEOFError("Transport is closed")
|
||||
|
||||
self._read_buffer.extend(chunk)
|
||||
|
||||
if n - len(to_return) > 0:
|
||||
# Drain the buffer if we still need more data
|
||||
to_return += self._drain_buffer(n - len(to_return))
|
||||
else:
|
||||
# No more data needed, break
|
||||
break
|
||||
|
||||
if self.q.qsize() == 0:
|
||||
# If no more data is available, break to return what we have
|
||||
break
|
||||
|
||||
return to_return
|
||||
|
||||
def _drain_buffer(self, n: int) -> bytes:
|
||||
data = bytes(self._read_buffer[:n])
|
||||
del self._read_buffer[:n]
|
||||
return data
|
||||
70
api/core/virtual_environment/channel/socket_transport.py
Normal file
70
api/core/virtual_environment/channel/socket_transport.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import socket
|
||||
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.transport import Transport, TransportReadCloser, TransportWriteCloser
|
||||
|
||||
|
||||
class SocketTransport(Transport):
|
||||
"""
|
||||
A Transport implementation using a socket.
|
||||
"""
|
||||
|
||||
def __init__(self, sock: socket.SocketIO):
|
||||
self.sock = sock
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
try:
|
||||
self.sock.write(data)
|
||||
except (ConnectionResetError, BrokenPipeError):
|
||||
raise TransportEOFError("Socket write error, maybe the read end is closed")
|
||||
|
||||
def read(self, n: int) -> bytes:
|
||||
try:
|
||||
data = self.sock.read(n)
|
||||
if data == b"":
|
||||
raise TransportEOFError("End of Socket reached")
|
||||
except (ConnectionResetError, BrokenPipeError):
|
||||
raise TransportEOFError("Socket connection reset")
|
||||
return data
|
||||
|
||||
def close(self) -> None:
|
||||
self.sock.close()
|
||||
|
||||
|
||||
class SocketReadCloser(TransportReadCloser):
|
||||
"""
|
||||
A Transport implementation using a socket for reading.
|
||||
"""
|
||||
|
||||
def __init__(self, sock: socket.SocketIO):
|
||||
self.sock = sock
|
||||
|
||||
def read(self, n: int) -> bytes:
|
||||
try:
|
||||
data = self.sock.read(n)
|
||||
if data == b"":
|
||||
raise TransportEOFError("End of Socket reached")
|
||||
return data
|
||||
except (ConnectionResetError, BrokenPipeError):
|
||||
raise TransportEOFError("Socket connection reset")
|
||||
|
||||
def close(self) -> None:
|
||||
self.sock.close()
|
||||
|
||||
|
||||
class SocketWriteCloser(TransportWriteCloser):
|
||||
"""
|
||||
A Transport implementation using a socket for writing.
|
||||
"""
|
||||
|
||||
def __init__(self, sock: socket.SocketIO):
|
||||
self.sock = sock
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
try:
|
||||
self.sock.write(data)
|
||||
except (ConnectionResetError, BrokenPipeError):
|
||||
raise TransportEOFError("Socket write error, maybe the read end is closed")
|
||||
|
||||
def close(self) -> None:
|
||||
self.sock.close()
|
||||
80
api/core/virtual_environment/channel/transport.py
Normal file
80
api/core/virtual_environment/channel/transport.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TransportCloser(Protocol):
|
||||
"""
|
||||
Transport that can be closed.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the transport.
|
||||
"""
|
||||
|
||||
|
||||
class TransportWriter(Protocol):
|
||||
"""
|
||||
Transport that can be written to.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Write data to the transport.
|
||||
|
||||
Raises TransportEOFError if the transport is closed.
|
||||
"""
|
||||
|
||||
|
||||
class TransportReader(Protocol):
|
||||
"""
|
||||
Transport that can be read from.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def read(self, n: int) -> bytes:
|
||||
"""
|
||||
Read up to n bytes from the transport.
|
||||
|
||||
Raises TransportEOFError if the end of the transport is reached.
|
||||
"""
|
||||
|
||||
|
||||
class TransportReadCloser(TransportReader, TransportCloser):
|
||||
"""
|
||||
Transport that can be read from and closed.
|
||||
"""
|
||||
|
||||
|
||||
class TransportWriteCloser(TransportWriter, TransportCloser):
|
||||
"""
|
||||
Transport that can be written to and closed.
|
||||
"""
|
||||
|
||||
|
||||
class Transport(TransportReader, TransportWriter, TransportCloser):
|
||||
"""
|
||||
Transport that can be read from, written to, and closed.
|
||||
"""
|
||||
|
||||
|
||||
class NopTransportWriteCloser(TransportWriteCloser):
|
||||
"""
|
||||
A no-operation TransportWriteCloser implementation.
|
||||
|
||||
This transport does nothing on write and close operations.
|
||||
"""
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
"""
|
||||
No-operation write method.
|
||||
"""
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
No-operation close method.
|
||||
"""
|
||||
pass
|
||||
319
api/core/virtual_environment/providers/docker_daemon_sandbox.py
Normal file
319
api/core/virtual_environment/providers/docker_daemon_sandbox.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import socket
|
||||
import tarfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from functools import lru_cache
|
||||
from io import BytesIO
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import docker.errors
|
||||
from docker.models.containers import Container
|
||||
|
||||
import docker
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.exec import VirtualEnvironmentLaunchFailedError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.socket_transport import SocketReadCloser, SocketWriteCloser
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
|
||||
|
||||
"""
|
||||
EXAMPLE:
|
||||
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from
|
||||
|
||||
from core.virtual_environment.providers.docker_daemon_sandbox import DockerDaemonEnvironment
|
||||
|
||||
options: Mapping[str, Any] = {
|
||||
# OptionsKey values are optional
|
||||
# DockerDaemonEnvironment.OptionsKey.DOCKER_SOCK: "unix:///var/run/docker.sock",
|
||||
# DockerDaemonEnvironment.OptionsKey.DOCKER_AGENT_IMAGE: "ubuntu:latest",
|
||||
# DockerDaemonEnvironment.OptionsKey.DOCKER_AGENT_COMMAND
|
||||
#
|
||||
"docker_sock": "unix:///var/run/docker.sock", # optional, default to unix socket
|
||||
"docker_agent_image": "ubuntu:latest", # optional, default to ubuntu:latest
|
||||
"docker_agent_command": "/bin/sh -c 'while true; do sleep 1; done'", # optional, default to None
|
||||
}
|
||||
|
||||
|
||||
environment = DockerDaemonEnvironment(options=options)
|
||||
connection_handle = environment.establish_connection()
|
||||
|
||||
pid, transport_stdout, transport_stderr, transport_stdin = environment.execute_command(
|
||||
connection_handle, ["uname", "-a"]
|
||||
)
|
||||
|
||||
print(f"Executed command with PID: {pid}")
|
||||
|
||||
# consume stdout
|
||||
# consume stdout
|
||||
while True:
|
||||
try:
|
||||
output = transport_stdout.read(1024)
|
||||
except TransportEOFError:
|
||||
logger.info("End of stdout reached")
|
||||
break
|
||||
|
||||
logger.info("Command output: %s", output.decode().strip())
|
||||
|
||||
|
||||
environment.release_connection(connection_handle)
|
||||
environment.release_environment()
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class DockerDaemonEnvironment(VirtualEnvironment):
|
||||
_WORKING_DIR = "/workspace"
|
||||
_DEAFULT_DOCKER_IMAGE = "ubuntu:latest"
|
||||
_DEFAULT_DOCKER_SOCK = "unix:///var/run/docker.sock"
|
||||
|
||||
class OptionsKey(StrEnum):
|
||||
DOCKER_SOCK = "docker_sock"
|
||||
DOCKER_IMAGE = "docker_image"
|
||||
DOCKER_COMMAND = "docker_command"
|
||||
|
||||
def construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
"""
|
||||
Construct the Docker daemon virtual environment.
|
||||
"""
|
||||
docker_client = self.get_docker_daemon(
|
||||
docker_sock=options.get(self.OptionsKey.DOCKER_SOCK, self._DEFAULT_DOCKER_SOCK)
|
||||
)
|
||||
|
||||
default_docker_image = options.get(self.OptionsKey.DOCKER_IMAGE, self._DEAFULT_DOCKER_IMAGE)
|
||||
container_command = options.get(self.OptionsKey.DOCKER_COMMAND)
|
||||
|
||||
container = docker_client.containers.run(
|
||||
image=default_docker_image,
|
||||
command=container_command,
|
||||
detach=True,
|
||||
remove=True,
|
||||
stdin_open=True,
|
||||
working_dir=self._WORKING_DIR,
|
||||
environment=dict(environments),
|
||||
)
|
||||
|
||||
# wait for the container to be fully started
|
||||
container.reload()
|
||||
|
||||
if not container.id:
|
||||
raise VirtualEnvironmentLaunchFailedError("Failed to start Docker container for DockerDaemonEnvironment.")
|
||||
|
||||
return Metadata(
|
||||
id=container.id,
|
||||
arch=self._get_container_architecture(container),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=5)
|
||||
def get_docker_daemon(cls, docker_sock: str) -> docker.DockerClient:
|
||||
"""
|
||||
Get the Docker daemon client.
|
||||
|
||||
NOTE: I guess nobody will use more than 5 different docker sockets in practice....
|
||||
"""
|
||||
return docker.DockerClient(base_url=docker_sock)
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=5)
|
||||
def get_docker_api_client(cls, docker_sock: str) -> docker.APIClient:
|
||||
"""
|
||||
Get the Docker low-level API client.
|
||||
"""
|
||||
return docker.APIClient(base_url=docker_sock)
|
||||
|
||||
def get_docker_sock(self) -> str:
|
||||
"""
|
||||
Get the Docker socket path.
|
||||
"""
|
||||
return self.options.get(self.OptionsKey.DOCKER_SOCK, self._DEFAULT_DOCKER_SOCK)
|
||||
|
||||
@property
|
||||
def _working_dir(self) -> str:
|
||||
"""
|
||||
Get the working directory inside the Docker container.
|
||||
"""
|
||||
return self._WORKING_DIR
|
||||
|
||||
def _get_container(self) -> Container:
|
||||
"""
|
||||
Get the Docker container instance.
|
||||
"""
|
||||
docker_client = self.get_docker_daemon(self.get_docker_sock())
|
||||
return docker_client.containers.get(self.metadata.id)
|
||||
|
||||
def _normalize_relative_path(self, path: str) -> PurePosixPath:
|
||||
parts: list[str] = []
|
||||
for part in PurePosixPath(path).parts:
|
||||
if part in ("", ".", "/"):
|
||||
continue
|
||||
if part == "..":
|
||||
if not parts:
|
||||
raise ValueError("Path escapes the workspace.")
|
||||
parts.pop()
|
||||
continue
|
||||
parts.append(part)
|
||||
return PurePosixPath(*parts)
|
||||
|
||||
def _relative_path(self, path: str) -> PurePosixPath:
|
||||
normalized = self._normalize_relative_path(path)
|
||||
if normalized.parts:
|
||||
return normalized
|
||||
return PurePosixPath()
|
||||
|
||||
def _container_path(self, path: str) -> str:
|
||||
relative = self._relative_path(path)
|
||||
if not relative.parts:
|
||||
return self._working_dir
|
||||
return f"{self._working_dir}/{relative.as_posix()}"
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
container = self._get_container()
|
||||
relative_path = self._relative_path(path)
|
||||
if not relative_path.parts:
|
||||
raise ValueError("Upload path must point to a file within the workspace.")
|
||||
|
||||
payload = content.getvalue()
|
||||
tar_stream = BytesIO()
|
||||
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
|
||||
tar_info = tarfile.TarInfo(name=relative_path.as_posix())
|
||||
tar_info.size = len(payload)
|
||||
tar.addfile(tar_info, BytesIO(payload))
|
||||
tar_stream.seek(0)
|
||||
container.put_archive(self._working_dir, tar_stream.read()) # pyright: ignore[reportUnknownMemberType] #
|
||||
|
||||
def download_file(self, path: str) -> BytesIO:
|
||||
container = self._get_container()
|
||||
container_path = self._container_path(path)
|
||||
stream, _ = container.get_archive(container_path)
|
||||
tar_stream = BytesIO()
|
||||
for chunk in stream:
|
||||
tar_stream.write(chunk)
|
||||
tar_stream.seek(0)
|
||||
|
||||
with tarfile.open(fileobj=tar_stream, mode="r:*") as tar:
|
||||
members = [member for member in tar.getmembers() if member.isfile()]
|
||||
if not members:
|
||||
return BytesIO()
|
||||
extracted = tar.extractfile(members[0])
|
||||
if extracted is None:
|
||||
return BytesIO()
|
||||
return BytesIO(extracted.read())
|
||||
|
||||
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
|
||||
container = self._get_container()
|
||||
container_path = self._container_path(directory_path)
|
||||
relative_base = self._relative_path(directory_path)
|
||||
try:
|
||||
stream, _ = container.get_archive(container_path)
|
||||
except docker.errors.NotFound:
|
||||
return []
|
||||
tar_stream = BytesIO()
|
||||
for chunk in stream:
|
||||
tar_stream.write(chunk)
|
||||
tar_stream.seek(0)
|
||||
|
||||
files: list[FileState] = []
|
||||
archive_root = PurePosixPath(container_path).name
|
||||
with tarfile.open(fileobj=tar_stream, mode="r:*") as tar:
|
||||
for member in tar.getmembers():
|
||||
if not member.isfile():
|
||||
continue
|
||||
member_path = PurePosixPath(member.name)
|
||||
if member_path.parts and member_path.parts[0] == archive_root:
|
||||
member_path = PurePosixPath(*member_path.parts[1:])
|
||||
if not member_path.parts:
|
||||
continue
|
||||
relative_path = relative_base / member_path
|
||||
files.append(
|
||||
FileState(
|
||||
path=relative_path.as_posix(),
|
||||
size=member.size,
|
||||
created_at=int(member.mtime),
|
||||
updated_at=int(member.mtime),
|
||||
)
|
||||
)
|
||||
if len(files) >= limit:
|
||||
break
|
||||
return files
|
||||
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
return ConnectionHandle(id=uuid4().hex)
|
||||
|
||||
def release_connection(self, connection_handle: ConnectionHandle) -> None:
|
||||
# No action needed for Docker exec connections
|
||||
pass
|
||||
|
||||
def release_environment(self) -> None:
|
||||
try:
|
||||
container = self._get_container()
|
||||
except docker.errors.NotFound:
|
||||
return
|
||||
try:
|
||||
container.remove(force=True)
|
||||
except docker.errors.NotFound:
|
||||
return
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
container = self._get_container()
|
||||
container_id = container.id
|
||||
if not isinstance(container_id, str) or not container_id:
|
||||
raise RuntimeError("Docker container ID is not available for exec.")
|
||||
api_client = self.get_docker_api_client(self.get_docker_sock())
|
||||
exec_info: dict[str, object] = cast(
|
||||
dict[str, object],
|
||||
api_client.exec_create( # pyright: ignore[reportUnknownMemberType] #
|
||||
container_id,
|
||||
cmd=command,
|
||||
stdin=True,
|
||||
stdout=True,
|
||||
stderr=True,
|
||||
tty=False,
|
||||
workdir=self._working_dir,
|
||||
environment=environments,
|
||||
),
|
||||
)
|
||||
|
||||
if not isinstance(exec_info.get("Id"), str):
|
||||
raise RuntimeError("Failed to create Docker exec instance.")
|
||||
|
||||
exec_id: str = str(exec_info.get("Id"))
|
||||
raw_sock: socket.SocketIO = cast(socket.SocketIO, api_client.exec_start(exec_id, socket=True, tty=False)) # pyright: ignore[reportUnknownMemberType] #
|
||||
|
||||
stdin_transport = SocketWriteCloser(raw_sock)
|
||||
stdout_transport = SocketReadCloser(raw_sock)
|
||||
|
||||
return exec_id, stdin_transport, stdout_transport, stdout_transport
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
api_client = self.get_docker_api_client(self.get_docker_sock())
|
||||
inspect: dict[str, object] = cast(dict[str, object], api_client.exec_inspect(pid)) # pyright: ignore[reportUnknownMemberType] #
|
||||
exit_code = inspect.get("ExitCode")
|
||||
if inspect.get("Running") or exit_code is None:
|
||||
return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
|
||||
if not isinstance(exit_code, int):
|
||||
exit_code = None
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=exit_code)
|
||||
|
||||
def _get_container_architecture(self, container: Container) -> Arch:
|
||||
"""
|
||||
Get the architecture of the Docker container.
|
||||
"""
|
||||
return Arch.ARM64
|
||||
|
||||
# container.reload()
|
||||
# arch_str = str(container.attrs["Architecture"])
|
||||
# match arch_str.lower():
|
||||
# case "x86_64" | "amd64":
|
||||
# return Arch.AMD64
|
||||
# case "aarch64" | "arm64":
|
||||
# return Arch.ARM64
|
||||
# case _:
|
||||
# raise ArchNotSupportedError(f"Architecture {arch_str} is not supported in DockerDaemonEnvironment.")
|
||||
253
api/core/virtual_environment/providers/e2b_sandbox.py
Normal file
253
api/core/virtual_environment/providers/e2b_sandbox.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import os
|
||||
import shlex
|
||||
import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from functools import cached_property
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from e2b_code_interpreter import Sandbox
|
||||
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.exec import ArchNotSupportedError, NotSupportedOperationError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
|
||||
from core.virtual_environment.channel.transport import (
|
||||
NopTransportWriteCloser,
|
||||
TransportReadCloser,
|
||||
TransportWriteCloser,
|
||||
)
|
||||
|
||||
"""
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.providers.e2b_sandbox import E2BEnvironment
|
||||
|
||||
options: Mapping[str, Any] = {
|
||||
E2BEnvironment.OptionsKey.API_KEY: "?????????",
|
||||
E2BEnvironment.OptionsKey.E2B_DEFAULT_TEMPLATE: "code-interpreter-v1",
|
||||
E2BEnvironment.OptionsKey.E2B_LIST_FILE_DEPTH: 2,
|
||||
E2BEnvironment.OptionsKey.E2B_API_URL: "https://api.e2b.app",
|
||||
}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
# environment = DockerDaemonEnvironment(options=options)
|
||||
# environment = LocalVirtualEnvironment(options=options)
|
||||
environment = E2BEnvironment(options=options)
|
||||
|
||||
connection_handle = environment.establish_connection()
|
||||
|
||||
pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command(
|
||||
connection_handle, ["uname", "-a"]
|
||||
)
|
||||
|
||||
logger.info("Executed command with PID: %s", pid)
|
||||
|
||||
# consume stdout
|
||||
# consume stdout
|
||||
while True:
|
||||
try:
|
||||
output = transport_stdout.read(1024)
|
||||
except TransportEOFError:
|
||||
logger.info("End of stdout reached")
|
||||
break
|
||||
|
||||
logger.info("Command output: %s", output.decode().strip())
|
||||
|
||||
|
||||
environment.release_connection(connection_handle)
|
||||
environment.release_environment()
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class E2BEnvironment(VirtualEnvironment):
|
||||
"""
|
||||
E2B virtual environment provider.
|
||||
"""
|
||||
|
||||
_WORKDIR = "/home/user"
|
||||
_E2B_API_URL = "https://api.e2b.app"
|
||||
|
||||
class OptionsKey(StrEnum):
|
||||
API_KEY = "api_key"
|
||||
E2B_LIST_FILE_DEPTH = "e2b_list_file_depth"
|
||||
E2B_DEFAULT_TEMPLATE = "e2b_default_template"
|
||||
E2B_API_URL = "e2b_api_url"
|
||||
|
||||
class StoreKey(StrEnum):
|
||||
SANDBOX = "sandbox"
|
||||
|
||||
def construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
"""
|
||||
Construct a new E2B virtual environment.
|
||||
"""
|
||||
# TODO: add Dify as the user agent
|
||||
sandbox = Sandbox.create(
|
||||
template=options.get(self.OptionsKey.E2B_DEFAULT_TEMPLATE, "code-interpreter-v1"),
|
||||
api_key=options.get(self.OptionsKey.API_KEY, ""),
|
||||
api_url=options.get(self.OptionsKey.E2B_API_URL, self._E2B_API_URL),
|
||||
envs=dict(environments),
|
||||
)
|
||||
info = sandbox.get_info(api_key=options.get(self.OptionsKey.API_KEY, ""))
|
||||
output = sandbox.commands.run("uname -m").stdout.strip()
|
||||
|
||||
return Metadata(
|
||||
id=info.sandbox_id,
|
||||
arch=self._convert_architecture(output),
|
||||
store={
|
||||
self.StoreKey.SANDBOX: sandbox,
|
||||
},
|
||||
)
|
||||
|
||||
def release_environment(self) -> None:
|
||||
"""
|
||||
Release the E2B virtual environment.
|
||||
"""
|
||||
if not Sandbox.kill(api_key=self.api_key, sandbox_id=self.metadata.id):
|
||||
raise Exception(f"Failed to release E2B sandbox with ID: {self.metadata.id}")
|
||||
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
"""
|
||||
Establish a connection to the E2B virtual environment.
|
||||
"""
|
||||
return ConnectionHandle(id=uuid4().hex)
|
||||
|
||||
def release_connection(self, connection_handle: ConnectionHandle) -> None:
|
||||
"""
|
||||
Release the connection to the E2B virtual environment.
|
||||
"""
|
||||
pass
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
"""
|
||||
Upload a file to the E2B virtual environment.
|
||||
|
||||
Args:
|
||||
path (str): The path to upload the file to.
|
||||
content (BytesIO): The content of the file.
|
||||
"""
|
||||
path = os.path.join(self._WORKDIR, path.lstrip("/"))
|
||||
|
||||
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
|
||||
sandbox.files.write(path, content) # pyright: ignore[reportUnknownMemberType] #
|
||||
|
||||
def download_file(self, path: str) -> BytesIO:
|
||||
"""
|
||||
Download a file from the E2B virtual environment.
|
||||
|
||||
Args:
|
||||
path (str): The path to download the file from.
|
||||
Returns:
|
||||
BytesIO: The content of the file.
|
||||
"""
|
||||
path = os.path.join(self._WORKDIR, path.lstrip("/"))
|
||||
|
||||
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
|
||||
content = sandbox.files.read(path)
|
||||
return BytesIO(content.encode())
|
||||
|
||||
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
|
||||
"""
|
||||
List files in a directory of the E2B virtual environment.
|
||||
"""
|
||||
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
|
||||
directory_path = os.path.join(self._WORKDIR, directory_path.lstrip("/"))
|
||||
files_info = sandbox.files.list(directory_path, depth=self.options.get(self.OptionsKey.E2B_LIST_FILE_DEPTH, 3))
|
||||
return [
|
||||
FileState(
|
||||
path=os.path.relpath(file_info.path, self._WORKDIR),
|
||||
size=file_info.size,
|
||||
created_at=int(file_info.modified_time.timestamp()),
|
||||
updated_at=int(file_info.modified_time.timestamp()),
|
||||
)
|
||||
for file_info in files_info
|
||||
]
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
"""
|
||||
Execute a command in the E2B virtual environment.
|
||||
|
||||
STDIN is not yet supported. E2B's API is such a terrible mess... to support it may lead a bad design.
|
||||
as a result we leave it for future improvement.
|
||||
"""
|
||||
sandbox: Sandbox = self.metadata.store[self.StoreKey.SANDBOX]
|
||||
stdout_stream = QueueTransportReadCloser()
|
||||
stderr_stream = QueueTransportReadCloser()
|
||||
|
||||
threading.Thread(
|
||||
target=self._cmd_thread,
|
||||
args=(sandbox, command, environments, stdout_stream, stderr_stream),
|
||||
).start()
|
||||
|
||||
return (
|
||||
"N/A",
|
||||
NopTransportWriteCloser(), # stdin not supported yet
|
||||
stdout_stream,
|
||||
stderr_stream,
|
||||
)
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
"""
|
||||
Nop, E2B does not support getting command status yet.
|
||||
"""
|
||||
raise NotSupportedOperationError("E2B does not support getting command status yet.")
|
||||
|
||||
def _cmd_thread(
|
||||
self,
|
||||
sandbox: Sandbox,
|
||||
command: list[str],
|
||||
environments: Mapping[str, str] | None,
|
||||
stdout_stream: QueueTransportReadCloser,
|
||||
stderr_stream: QueueTransportReadCloser,
|
||||
) -> None:
|
||||
""" """
|
||||
stdout_stream_write_handler = stdout_stream.get_write_handler()
|
||||
stderr_stream_write_handler = stderr_stream.get_write_handler()
|
||||
|
||||
try:
|
||||
sandbox.commands.run(
|
||||
cmd=shlex.join(command),
|
||||
envs=dict(environments or {}),
|
||||
# stdin=True,
|
||||
on_stdout=lambda data: stdout_stream_write_handler.write(data.encode()),
|
||||
on_stderr=lambda data: stderr_stream_write_handler.write(data.encode()),
|
||||
)
|
||||
finally:
|
||||
# Close the write handlers to signal EOF
|
||||
stdout_stream.close()
|
||||
stderr_stream.close()
|
||||
|
||||
@cached_property
|
||||
def api_key(self) -> str:
|
||||
"""
|
||||
Get the API key for the E2B environment.
|
||||
"""
|
||||
return self.options.get(self.OptionsKey.API_KEY, "")
|
||||
|
||||
def _convert_architecture(self, arch_str: str) -> Arch:
|
||||
"""
|
||||
Convert architecture string to standard format.
|
||||
"""
|
||||
arch_map = {
|
||||
"x86_64": Arch.AMD64,
|
||||
"aarch64": Arch.ARM64,
|
||||
"armv7l": Arch.ARM64,
|
||||
"arm64": Arch.ARM64,
|
||||
"amd64": Arch.AMD64,
|
||||
"arm64v8": Arch.ARM64,
|
||||
"arm64v7": Arch.ARM64,
|
||||
}
|
||||
if arch_str in arch_map:
|
||||
return arch_map[arch_str]
|
||||
|
||||
raise ArchNotSupportedError(f"Unsupported architecture: {arch_str}")
|
||||
@@ -0,0 +1,274 @@
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from io import BytesIO
|
||||
from platform import machine
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from core.virtual_environment.__base.entities import Arch, CommandStatus, ConnectionHandle, FileState, Metadata
|
||||
from core.virtual_environment.__base.exec import ArchNotSupportedError
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.virtual_environment.channel.pipe_transport import PipeReadCloser, PipeWriteCloser
|
||||
from core.virtual_environment.channel.transport import TransportReadCloser, TransportWriteCloser
|
||||
|
||||
"""
|
||||
USAGE:
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
|
||||
|
||||
options: Mapping[str, Any] = {}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
environment = LocalVirtualEnvironment(options=options)
|
||||
|
||||
connection_handle = environment.establish_connection()
|
||||
|
||||
pid, transport_stdin, transport_stdout, transport_stderr = environment.execute_command(
|
||||
connection_handle,
|
||||
["sh", "-lc", "for i in 1 2 3 4 5; do date '+%F %T'; sleep 1; done"],
|
||||
)
|
||||
|
||||
logger.info("Executed command with PID: %s", pid)
|
||||
|
||||
# consume stdout
|
||||
while True:
|
||||
try:
|
||||
output = transport_stdout.read(1024)
|
||||
except TransportEOFError:
|
||||
logger.info("End of stdout reached")
|
||||
break
|
||||
|
||||
logger.info("Command output: %s", output.decode().strip())
|
||||
|
||||
|
||||
environment.release_connection(connection_handle)
|
||||
environment.release_environment()
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class LocalVirtualEnvironment(VirtualEnvironment):
|
||||
"""
|
||||
Local virtual environment provider without isolation.
|
||||
|
||||
WARNING: This provider does not provide any isolation. It's only suitable for development and testing purposes.
|
||||
NEVER USE IT IN PRODUCTION ENVIRONMENTS.
|
||||
"""
|
||||
|
||||
def construct_environment(self, options: Mapping[str, Any], environments: Mapping[str, str]) -> Metadata:
|
||||
"""
|
||||
Construct the local virtual environment.
|
||||
|
||||
Under local without isolation, this method simply create a path for the environment and return the metadata.
|
||||
"""
|
||||
id = uuid4().hex
|
||||
working_path = os.path.join(self._base_working_path, id)
|
||||
os.makedirs(working_path, exist_ok=True)
|
||||
return Metadata(
|
||||
id=id,
|
||||
arch=self._get_os_architecture(),
|
||||
)
|
||||
|
||||
def release_environment(self) -> None:
|
||||
"""
|
||||
Release the local virtual environment.
|
||||
|
||||
Just simply remove the working directory.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
if os.path.exists(working_path):
|
||||
os.rmdir(working_path)
|
||||
|
||||
def upload_file(self, path: str, content: BytesIO) -> None:
|
||||
"""
|
||||
Upload a file to the local virtual environment.
|
||||
|
||||
Args:
|
||||
path (str): The path to upload the file to.
|
||||
content (BytesIO): The content of the file.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
full_path = os.path.join(working_path, path)
|
||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||
pathlib.Path(full_path).write_bytes(content.getbuffer())
|
||||
|
||||
def download_file(self, path: str) -> BytesIO:
|
||||
"""
|
||||
Download a file from the local virtual environment.
|
||||
|
||||
Args:
|
||||
path (str): The path to download the file from.
|
||||
Returns:
|
||||
BytesIO: The content of the file.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
full_path = os.path.join(working_path, path)
|
||||
content = pathlib.Path(full_path).read_bytes()
|
||||
return BytesIO(content)
|
||||
|
||||
def list_files(self, directory_path: str, limit: int) -> Sequence[FileState]:
|
||||
"""
|
||||
List files in a directory of the local virtual environment.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
full_directory_path = os.path.join(working_path, directory_path)
|
||||
files: list[FileState] = []
|
||||
for root, _, filenames in os.walk(full_directory_path):
|
||||
for filename in filenames:
|
||||
if len(files) >= limit:
|
||||
break
|
||||
file_path = os.path.relpath(os.path.join(root, filename), working_path)
|
||||
state = os.stat(os.path.join(root, filename))
|
||||
files.append(
|
||||
FileState(
|
||||
path=file_path,
|
||||
size=state.st_size,
|
||||
created_at=int(state.st_ctime),
|
||||
updated_at=int(state.st_mtime),
|
||||
)
|
||||
)
|
||||
if len(files) >= limit:
|
||||
# break the outer loop as well
|
||||
return files
|
||||
|
||||
return files
|
||||
|
||||
def establish_connection(self) -> ConnectionHandle:
|
||||
"""
|
||||
Establish a connection to the local virtual environment.
|
||||
"""
|
||||
return ConnectionHandle(
|
||||
id=uuid4().hex,
|
||||
)
|
||||
|
||||
def release_connection(self, connection_handle: ConnectionHandle) -> None:
|
||||
"""
|
||||
Release the connection to the local virtual environment.
|
||||
"""
|
||||
# No action needed for local without isolation
|
||||
pass
|
||||
|
||||
def execute_command(
|
||||
self, connection_handle: ConnectionHandle, command: list[str], environments: Mapping[str, str] | None = None
|
||||
) -> tuple[str, TransportWriteCloser, TransportReadCloser, TransportReadCloser]:
|
||||
"""
|
||||
Execute a command in the local virtual environment.
|
||||
|
||||
Args:
|
||||
connection_handle (ConnectionHandle): The connection handle.
|
||||
command (list[str]): The command to execute.
|
||||
"""
|
||||
working_path = self.get_working_path()
|
||||
stdin_read_fd, stdin_write_fd = os.pipe()
|
||||
stdout_read_fd, stdout_write_fd = os.pipe()
|
||||
stderr_read_fd, stderr_write_fd = os.pipe()
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
stdin=stdin_read_fd,
|
||||
stdout=stdout_write_fd,
|
||||
stderr=stderr_write_fd,
|
||||
cwd=working_path,
|
||||
close_fds=True,
|
||||
env=environments,
|
||||
)
|
||||
except Exception:
|
||||
# Clean up file descriptors if process creation fails
|
||||
for fd in (
|
||||
stdin_read_fd,
|
||||
stdin_write_fd,
|
||||
stdout_read_fd,
|
||||
stdout_write_fd,
|
||||
stderr_read_fd,
|
||||
stderr_write_fd,
|
||||
):
|
||||
try:
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
# Close unused fds in the parent process
|
||||
os.close(stdin_read_fd)
|
||||
os.close(stdout_write_fd)
|
||||
os.close(stderr_write_fd)
|
||||
|
||||
# Create PipeTransport instances for stdin, stdout, and stderr
|
||||
stdin_transport = PipeWriteCloser(w_fd=stdin_write_fd)
|
||||
stdout_transport = PipeReadCloser(r_fd=stdout_read_fd)
|
||||
stderr_transport = PipeReadCloser(r_fd=stderr_read_fd)
|
||||
|
||||
# Return the process ID and file descriptors for stdin, stdout, and stderr
|
||||
return str(process.pid), stdin_transport, stdout_transport, stderr_transport
|
||||
|
||||
def get_command_status(self, connection_handle: ConnectionHandle, pid: str) -> CommandStatus:
|
||||
"""
|
||||
Docstring for get_command_status
|
||||
|
||||
:param self: Description
|
||||
:param connection_handle: Description
|
||||
:type connection_handle: ConnectionHandle
|
||||
:param pid: Description
|
||||
:type pid: int
|
||||
:return: Description
|
||||
:rtype: CommandStatus
|
||||
"""
|
||||
pid_int = int(pid)
|
||||
try:
|
||||
retcode = os.waitpid(pid_int, os.WNOHANG)[1]
|
||||
if retcode == 0:
|
||||
return CommandStatus(status=CommandStatus.Status.RUNNING, exit_code=None)
|
||||
else:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=retcode)
|
||||
except ChildProcessError:
|
||||
return CommandStatus(status=CommandStatus.Status.COMPLETED, exit_code=None)
|
||||
|
||||
def _get_os_architecture(self) -> Arch:
|
||||
"""
|
||||
Get the operating system architecture.
|
||||
|
||||
Returns:
|
||||
Arch: The operating system architecture.
|
||||
"""
|
||||
|
||||
arch = machine()
|
||||
match arch.lower():
|
||||
case "x86_64" | "amd64":
|
||||
return Arch.AMD64
|
||||
case "aarch64" | "arm64":
|
||||
return Arch.ARM64
|
||||
case _:
|
||||
raise ArchNotSupportedError(f"Unsupported architecture: {arch}")
|
||||
|
||||
@cached_property
|
||||
def _base_working_path(self) -> str:
|
||||
"""
|
||||
Get the base working path for the local virtual environment.
|
||||
|
||||
Args:
|
||||
options (Mapping[str, Any]): Options for requesting the virtual environment.
|
||||
|
||||
Returns:
|
||||
str: The base working path.
|
||||
"""
|
||||
cwd = os.getcwd()
|
||||
return self.options.get("base_working_path", os.path.join(cwd, "local_virtual_environments"))
|
||||
|
||||
def get_working_path(self) -> str:
|
||||
"""
|
||||
Get the working path for the local virtual environment.
|
||||
|
||||
Returns:
|
||||
str: The working path.
|
||||
"""
|
||||
return os.path.join(self._base_working_path, self.metadata.id)
|
||||
@@ -88,11 +88,13 @@ dependencies = [
|
||||
"httpx-sse~=0.4.0",
|
||||
"sendgrid~=6.12.3",
|
||||
"flask-restx~=1.3.0",
|
||||
"packaging~=23.2",
|
||||
"packaging==24.1",
|
||||
"croniter>=6.0.0",
|
||||
"weaviate-client==4.17.0",
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"docker>=7.1.0",
|
||||
"e2b-code-interpreter>=2.4.1",
|
||||
]
|
||||
# Before adding new dependency, consider place it in
|
||||
# alphabet order (a-z) and suitable group.
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
import os
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
|
||||
from core.virtual_environment.channel.exec import TransportEOFError
|
||||
from core.virtual_environment.channel.pipe_transport import PipeReadCloser, PipeTransport, PipeWriteCloser
|
||||
from core.virtual_environment.channel.queue_transport import QueueTransportReadCloser
|
||||
from core.virtual_environment.channel.socket_transport import SocketReadCloser, SocketTransport, SocketWriteCloser
|
||||
|
||||
|
||||
def _close_socket(sock: socket.socket) -> None:
|
||||
try:
|
||||
sock.close()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def test_queue_transport_reads_across_chunks() -> None:
|
||||
transport = QueueTransportReadCloser()
|
||||
writer = transport.get_write_handler()
|
||||
writer.write(b"hello")
|
||||
writer.write(b"world")
|
||||
|
||||
data = transport.read(8)
|
||||
assert data == b"hellowor"
|
||||
assert transport.read(2) == b"ld"
|
||||
|
||||
|
||||
def test_queue_transport_close_then_read_raises() -> None:
|
||||
transport = QueueTransportReadCloser()
|
||||
transport.close()
|
||||
|
||||
with pytest.raises(TransportEOFError):
|
||||
transport.read(1)
|
||||
|
||||
|
||||
def test_queue_transport_close_twice_raises() -> None:
|
||||
transport = QueueTransportReadCloser()
|
||||
transport.close()
|
||||
|
||||
with pytest.raises(TransportEOFError):
|
||||
transport.close()
|
||||
|
||||
|
||||
def test_pipe_transport_roundtrip() -> None:
|
||||
r_fd, w_fd = os.pipe()
|
||||
transport = PipeTransport(r_fd, w_fd)
|
||||
try:
|
||||
transport.write(b"ping")
|
||||
assert transport.read(4) == b"ping"
|
||||
finally:
|
||||
transport.close()
|
||||
|
||||
|
||||
def test_pipe_read_closer_eof_raises() -> None:
|
||||
r_fd, w_fd = os.pipe()
|
||||
os.close(w_fd)
|
||||
reader = PipeReadCloser(r_fd)
|
||||
try:
|
||||
with pytest.raises(TransportEOFError):
|
||||
reader.read(1)
|
||||
finally:
|
||||
reader.close()
|
||||
|
||||
|
||||
def test_pipe_write_closer_eof_raises() -> None:
|
||||
r_fd, w_fd = os.pipe()
|
||||
os.close(r_fd)
|
||||
writer = PipeWriteCloser(w_fd)
|
||||
try:
|
||||
with pytest.raises(TransportEOFError):
|
||||
writer.write(b"x")
|
||||
finally:
|
||||
writer.close()
|
||||
|
||||
|
||||
def test_socket_transport_roundtrip() -> None:
|
||||
sock_a, sock_b = socket.socketpair()
|
||||
sock_a_io = sock_a.makefile("rwb", buffering=0)
|
||||
sock_b_io = sock_b.makefile("rwb", buffering=0)
|
||||
transport_a = SocketTransport(sock_a_io)
|
||||
transport_b = SocketTransport(sock_b_io)
|
||||
try:
|
||||
transport_a.write(b"x")
|
||||
assert transport_b.read(1) == b"x"
|
||||
finally:
|
||||
transport_a.close()
|
||||
transport_b.close()
|
||||
_close_socket(sock_a)
|
||||
_close_socket(sock_b)
|
||||
|
||||
|
||||
def test_socket_read_closer_eof_raises() -> None:
|
||||
sock_a, sock_b = socket.socketpair()
|
||||
sock_a_io = sock_a.makefile("rb", buffering=0)
|
||||
reader = SocketReadCloser(sock_a_io)
|
||||
try:
|
||||
sock_b.close()
|
||||
with pytest.raises(TransportEOFError):
|
||||
reader.read(1)
|
||||
finally:
|
||||
reader.close()
|
||||
_close_socket(sock_a)
|
||||
|
||||
|
||||
def test_socket_write_closer_writes() -> None:
|
||||
sock_a, sock_b = socket.socketpair()
|
||||
sock_a_io = sock_a.makefile("wb", buffering=0)
|
||||
sock_b_io = sock_b.makefile("rb", buffering=0)
|
||||
writer = SocketWriteCloser(sock_a_io)
|
||||
reader = SocketReadCloser(sock_b_io)
|
||||
try:
|
||||
writer.write(b"y")
|
||||
assert reader.read(1) == b"y"
|
||||
finally:
|
||||
writer.close()
|
||||
reader.close()
|
||||
_close_socket(sock_a)
|
||||
_close_socket(sock_b)
|
||||
@@ -0,0 +1,101 @@
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from core.virtual_environment.providers import local_without_isolation
|
||||
from core.virtual_environment.providers.local_without_isolation import LocalVirtualEnvironment
|
||||
|
||||
|
||||
def _read_all(fd: int) -> bytes:
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
data = os.read(fd, 4096)
|
||||
if not data:
|
||||
break
|
||||
chunks.append(data)
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
def _close_fds(*fds: int) -> None:
|
||||
for fd in fds:
|
||||
try:
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> LocalVirtualEnvironment:
|
||||
monkeypatch.setattr(local_without_isolation, "machine", lambda: "x86_64")
|
||||
return LocalVirtualEnvironment({"base_working_path": str(tmp_path)})
|
||||
|
||||
|
||||
def test_construct_environment_creates_working_path(local_env: LocalVirtualEnvironment):
|
||||
working_path = local_env.get_working_path()
|
||||
assert local_env.metadata.id
|
||||
assert os.path.isdir(working_path)
|
||||
|
||||
|
||||
def test_upload_download_roundtrip(local_env: LocalVirtualEnvironment):
|
||||
content = BytesIO(b"payload")
|
||||
local_env.upload_file("nested/file.txt", content)
|
||||
|
||||
downloaded = local_env.download_file("nested/file.txt")
|
||||
|
||||
assert downloaded.getvalue() == b"payload"
|
||||
|
||||
|
||||
def test_list_files_respects_limit(local_env: LocalVirtualEnvironment):
|
||||
local_env.upload_file("dir/file_a.txt", BytesIO(b"a"))
|
||||
local_env.upload_file("file_b.txt", BytesIO(b"b"))
|
||||
|
||||
all_files = local_env.list_files("", limit=10)
|
||||
all_paths = {state.path for state in all_files}
|
||||
|
||||
assert os.path.join("dir", "file_a.txt") in all_paths
|
||||
assert "file_b.txt" in all_paths
|
||||
|
||||
limited_files = local_env.list_files("", limit=1)
|
||||
assert len(limited_files) == 1
|
||||
|
||||
|
||||
def test_execute_command_uses_working_directory(local_env: LocalVirtualEnvironment):
|
||||
local_env.upload_file("message.txt", BytesIO(b"hello"))
|
||||
connection = local_env.establish_connection()
|
||||
command = ["/bin/sh", "-c", "cat message.txt"]
|
||||
|
||||
pid, stdin_fd, stdout_fd, stderr_fd = local_env.execute_command(connection, command)
|
||||
|
||||
try:
|
||||
os.close(stdin_fd)
|
||||
if hasattr(os, "waitpid"):
|
||||
os.waitpid(pid, 0)
|
||||
stdout = _read_all(stdout_fd)
|
||||
stderr = _read_all(stderr_fd)
|
||||
finally:
|
||||
_close_fds(stdin_fd, stdout_fd, stderr_fd)
|
||||
|
||||
assert stdout == b"hello"
|
||||
assert stderr == b""
|
||||
|
||||
|
||||
def test_execute_command_pipes_stdio(local_env: LocalVirtualEnvironment):
|
||||
connection = local_env.establish_connection()
|
||||
command = ["/bin/sh", "-c", "tr a-z A-Z < /dev/stdin; printf ERR >&2"]
|
||||
|
||||
pid, stdin_fd, stdout_fd, stderr_fd = local_env.execute_command(connection, command)
|
||||
|
||||
try:
|
||||
os.write(stdin_fd, b"abc")
|
||||
os.close(stdin_fd)
|
||||
if hasattr(os, "waitpid"):
|
||||
os.waitpid(pid, 0)
|
||||
stdout = _read_all(stdout_fd)
|
||||
stderr = _read_all(stderr_fd)
|
||||
finally:
|
||||
_close_fds(stdin_fd, stdout_fd, stderr_fd)
|
||||
|
||||
assert stdout == b"ABC"
|
||||
assert stderr == b"ERR"
|
||||
4719
api/uv.lock
generated
4719
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user