Skip to content

Commit

Permalink
Merge branch 'main' into dutlink-test-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCao authored Feb 3, 2025
2 parents 82981eb + ee4720b commit 3371b1d
Show file tree
Hide file tree
Showing 21 changed files with 261 additions and 222 deletions.
10 changes: 4 additions & 6 deletions __templates__/driver/jumpstarter_driver/driver.py.tmpl
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import logging
from dataclasses import dataclass

from jumpstarter.driver import Driver, export

logger = logging.getLogger(__name__)

@dataclass(kw_only=True)
class ${DRIVER_CLASS}(Driver):
"""${DRIVE_NAME} driver for Jumpstarter"""
Expand All @@ -13,7 +10,8 @@ class ${DRIVER_CLASS}(Driver):
some_other_config: int = 69

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()
# some initialization here.

@classmethod
Expand All @@ -22,10 +20,10 @@ class ${DRIVER_CLASS}(Driver):

@export
def method1(self):
logger.info("Method1 called")
self.logger.info("Method1 called")
return "method1 response"

@export
def method2(self):
logger.info("Method2 called")
self.logger.info("Method2 called")
return "method2 response"
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ class CanClient(DriverClient, can.BusABC):
"""

def __post_init__(self):
if hasattr(super(), "__post_init__"):
super().__post_init__()

self._periodic_tasks: List[_SelfRemovingCyclicTask] = []
self._filters = None
self._is_shutdown: bool = False

super().__post_init__()

@property
@validate_call(validate_return=True)
def state(self) -> can.BusState:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def client(cls) -> str:
return "jumpstarter_driver_can.client.CanClient"

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.bus = can.Bus(channel=self.channel, interface=self.interface)

@export
Expand Down Expand Up @@ -195,7 +197,9 @@ def client(cls) -> str:
return "jumpstarter_driver_can.client.IsoTpClient"

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.bus = can.Bus(channel=self.channel, interface=self.interface)
self.notifier = can.Notifier(self.bus, [])
self.stack = isotp.NotifierBasedCanStack(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import logging
import os
import time
from collections.abc import AsyncGenerator
Expand All @@ -19,8 +18,6 @@

from jumpstarter.driver import Driver, export

log = logging.getLogger(__name__)


@dataclass(kw_only=True)
class DutlinkConfig:
Expand All @@ -32,10 +29,13 @@ class DutlinkConfig:
tty: str | None = field(init=False, default=None)

def __post_init__(self):
if hasattr(super(), "__post_init__"):
super().__post_init__()

for dev in usb.core.find(idVendor=0x2B23, idProduct=0x1012, find_all=True):
serial = usb.util.get_string(dev, dev.iSerialNumber)
if serial == self.serial or self.serial is None:
log.debug(f"found dutlink board with serial {serial}")
self.logger.debug(f"found dutlink board with serial {serial}")

self.serial = serial
self.dev = dev
Expand Down Expand Up @@ -78,30 +78,30 @@ def control(self, direction, ty, actions, action, value):

if direction == usb.ENDPOINT_IN:
str_value = bytes(res).decode("utf-8")
log.debug(
"ctrl_transfer result: %s",
)
self.logger.debug("ctrl_transfer result: %s", str_value)
return str_value


@dataclass(kw_only=True)
class DutlinkSerial(DutlinkConfig, PySerial):
url: str | None = field(init=False, default=None)

class DutlinkSerialConfig(DutlinkConfig, Driver):
def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.url = self.tty

super(PySerial, self).__post_init__()

@dataclass(kw_only=True)
class DutlinkSerial(PySerial, DutlinkSerialConfig):
url: str | None = field(init=False, default=None)


@dataclass(kw_only=True)
class DutlinkPower(DutlinkConfig, PowerInterface, Driver):
last_action: str | None = field(default=None)

def control(self, action):
log.debug(f"power control: {action}")
self.logger.debug(f"power control: {action}")
if self.last_action == action:
return

Expand Down Expand Up @@ -160,7 +160,7 @@ class DutlinkStorageMux(DutlinkConfig, StorageMuxInterface, Driver):
storage_device: str

def control(self, action):
log.debug(f"storage control: {action}")
self.logger.debug(f"storage control: {action}")
return super().control(
usb.ENDPOINT_OUT,
0x02,
Expand Down Expand Up @@ -190,9 +190,9 @@ def off(self):
async def wait_for_storage_device(self):
with fail_after(20):
while True:
log.debug(f"waiting for storage device {self.storage_device}")
self.logger.debug(f"waiting for storage device {self.storage_device}")
if os.path.exists(self.storage_device):
log.debug(f"storage device {self.storage_device} is ready")
self.logger.debug(f"storage device {self.storage_device} is ready")
# https://stackoverflow.com/a/2774125
fd = os.open(self.storage_device, os.O_WRONLY)
try:
Expand All @@ -213,7 +213,7 @@ async def write(self, src: str):
async for chunk in res:
await stream.send(chunk)
if total_bytes > next_print:
log.debug(f"{self.storage_device} written {total_bytes / (1024 * 1024)} MB")
self.logger.debug(f"{self.storage_device} written {total_bytes / (1024 * 1024)} MB")
next_print += 50 * 1024 * 1024
total_bytes += len(chunk)

Expand Down Expand Up @@ -252,18 +252,20 @@ class Dutlink(DutlinkConfig, CompositeInterface, Driver):
"""

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.children["power"] = DutlinkPower(serial=self.serial, timeout_s=self.timeout_s)
self.children["storage"] = DutlinkStorageMux(serial=self.serial, storage_device=self.storage_device,
timeout_s=self.timeout_s)
self.children["storage"] = DutlinkStorageMux(
serial=self.serial, storage_device=self.storage_device, timeout_s=self.timeout_s
)

# if an alternate serial port has been requested, use it
if self.alternate_console is not None:
try:
self.children["console"] = PySerial(url=self.alternate_console, baudrate=self.baudrate)
except SerialException:
log.info(
self.logger.info(
f"failed to open alternate console {self.alternate_console} but trying to power on the target once"
)
self.children["power"].on()
Expand Down
64 changes: 31 additions & 33 deletions packages/jumpstarter-driver-http/jumpstarter_driver_http/driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
Expand All @@ -10,8 +9,6 @@

from jumpstarter.driver import Driver, export

logger = logging.getLogger(__name__)


class HttpServerError(Exception):
"""Base exception for HTTP server errors"""
Expand All @@ -21,38 +18,39 @@ class FileWriteError(HttpServerError):
"""Exception raised when file writing fails"""


def get_default_ip():
try:
import socket

s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
s.close()
return ip
except Exception:
logger.warning("Could not determine default IP address, falling back to 0.0.0.0")
return "0.0.0.0"


@dataclass(kw_only=True)
class HttpServer(Driver):
"""HTTP Server driver for Jumpstarter"""

root_dir: str = "/var/www"
host: str = field(default_factory=get_default_ip)
host: str = field(default=None)
port: int = 8080
app: web.Application = field(init=False, default_factory=web.Application)
runner: Optional[web.AppRunner] = field(init=False, default=None)

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

os.makedirs(self.root_dir, exist_ok=True)
self.app.router.add_routes(
[
web.get("/{filename}", self.get_file),
]
)
if self.host is None:
self.host = self.get_default_ip()

def get_default_ip(self):
try:
import socket

with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
except Exception:
self.logger.warning("Could not determine default IP address, falling back to 0.0.0.0")
return "0.0.0.0"

@classmethod
def client(cls) -> str:
Expand Down Expand Up @@ -86,11 +84,11 @@ async def put_file(self, filename: str, src_stream) -> str:
async for chunk in src:
await dst.send(chunk)

logger.info(f"File '{filename}' written to '{file_path}'")
self.logger.info(f"File '{filename}' written to '{file_path}'")
return f"{self.get_url()}/{filename}"

except Exception as e:
logger.error(f"Failed to upload file '{filename}': {e}")
self.logger.error(f"Failed to upload file '{filename}': {e}")
raise FileWriteError(f"Failed to upload file '{filename}': {e}") from e

@export
Expand All @@ -112,10 +110,10 @@ async def delete_file(self, filename: str) -> str:
raise HttpServerError(f"File '{filename}' does not exist.")
try:
file_path.unlink()
logger.info(f"File '{filename}' has been deleted.")
self.logger.info(f"File '{filename}' has been deleted.")
return filename
except Exception as e:
logger.error(f"Failed to delete file '{filename}': {e}")
self.logger.error(f"Failed to delete file '{filename}': {e}")
raise HttpServerError(f"Failed to delete file '{filename}': {e}") from e

async def get_file(self, request) -> web.FileResponse:
Expand All @@ -134,9 +132,9 @@ async def get_file(self, request) -> web.FileResponse:
filename = request.match_info["filename"]
file_path = os.path.join(self.root_dir, filename)
if not os.path.isfile(file_path):
logger.warning(f"File not found: {file_path}")
self.logger.warning(f"File not found: {file_path}")
raise web.HTTPNotFound(text=f"File '{filename}' not found.")
logger.info(f"Serving file: {file_path}")
self.logger.info(f"Serving file: {file_path}")
return web.FileResponse(file_path)

@export
Expand All @@ -155,7 +153,7 @@ def list_files(self) -> list[str]:
files = [f for f in files if os.path.isfile(os.path.join(self.root_dir, f))]
return files
except Exception as e:
logger.error(f"Failed to list files: {e}")
self.logger.error(f"Failed to list files: {e}")
raise HttpServerError(f"Failed to list files: {e}") from e

@export
Expand All @@ -167,7 +165,7 @@ async def start(self):
HttpServerError: If the server fails to start.
"""
if self.runner is not None:
logger.warning("HTTP server is already running.")
self.logger.warning("HTTP server is already running.")
return

self.runner = web.AppRunner(self.app)
Expand All @@ -176,7 +174,7 @@ async def start(self):

site = web.TCPSite(self.runner, self.host, self.port)
await site.start()
logger.info(f"HTTP server started at http://{self.host}:{self.port}")
self.logger.info(f"HTTP server started at http://{self.host}:{self.port}")

@export
async def stop(self):
Expand All @@ -187,11 +185,11 @@ async def stop(self):
HttpServerError: If the server fails to stop.
"""
if self.runner is None:
logger.warning("HTTP server is not running.")
self.logger.warning("HTTP server is not running.")
return

await self.runner.cleanup()
logger.info("HTTP server stopped.")
self.logger.info("HTTP server stopped.")
self.runner = None

@export
Expand Down Expand Up @@ -230,7 +228,7 @@ def close(self):
if anyio.get_current_task():
anyio.from_thread.run(self._async_cleanup)
except Exception as e:
logger.warning(f"HTTP server cleanup failed synchronously: {e}")
self.logger.warning(f"HTTP server cleanup failed synchronously: {e}")
self.runner = None
super().close()

Expand All @@ -239,6 +237,6 @@ async def _async_cleanup(self):
if self.runner:
await self.runner.shutdown()
await self.runner.cleanup()
logger.info("HTTP server cleanup completed asynchronously.")
self.logger.info("HTTP server cleanup completed asynchronously.")
except Exception as e:
logger.error(f"HTTP server cleanup failed asynchronously: {e}")
self.logger.error(f"HTTP server cleanup failed asynchronously: {e}")
Loading

0 comments on commit 3371b1d

Please sign in to comment.