Skip to content

Commit

Permalink
Merge pull request #241 from bennyz/tftp-checksum
Browse files Browse the repository at this point in the history
tftp: add checksum validation
  • Loading branch information
mangelajo authored Feb 3, 2025
2 parents 675b96c + 42c92d2 commit ee4720b
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 97 deletions.
62 changes: 24 additions & 38 deletions packages/jumpstarter-driver-tftp/examples/tftp_test.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,37 @@
import logging
import time

import pytest
from jumpstarter_driver_tftp.driver import FileNotFound, TftpError
from jumpstarter_testing.pytest import JumpstarterTest

log = logging.getLogger(__name__)


class TestResource(JumpstarterTest):
filter_labels = {"board": "rpi4"}

@pytest.fixture()
def test_tftp_upload(self, client):
def setup_tftp(self, client):
# Move the setup code to a fixture
client.tftp.start()
yield client
client.tftp.stop()

def test_tftp_operations(self, setup_tftp):
client = setup_tftp
test_file = "test.bin"

# Create test file
with open(test_file, "wb") as f:
f.write(b"Hello from TFTP streaming test!")

try:
client.tftp.start()
print("TFTP server started")

time.sleep(1)

test_file = "test.bin"
with open(test_file, "wb") as f:
f.write(b"Hello from TFTP streaming test!")

try:
client.tftp.put_local_file(test_file)
print(f"Successfully uploaded {test_file}")

files = client.tftp.list_files()
print(f"Files in TFTP root: {files}")

if test_file in files:
client.tftp.delete_file(test_file)
print(f"Successfully deleted {test_file}")
else:
print(f"Warning: {test_file} not found in TFTP root")

except TftpError as e:
print(f"TFTP operation failed: {e}")
except FileNotFound as e:
print(f"File not found: {e}")

except Exception as e:
print(f"Error: {e}")
finally:
try:
client.tftp.stop()
print("TFTP server stopped")
except Exception as e:
print(f"Error stopping server: {e}")
# Test upload
client.tftp.put_local_file(test_file)
assert test_file in client.tftp.list_files()

# Test delete
client.tftp.delete_file(test_file)
assert test_file not in client.tftp.list_files()

except (TftpError, FileNotFound) as e:
pytest.fail(f"Test failed: {e}")
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CHUNK_SIZE = 1024 * 1024 * 4 # 4MB
48 changes: 25 additions & 23 deletions packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import hashlib
from dataclasses import dataclass
from pathlib import Path

from jumpstarter_driver_opendal.adapter import OpendalAdapter
from opendal import Operator

from . import CHUNK_SIZE
from jumpstarter.client import DriverClient


Expand Down Expand Up @@ -46,37 +48,30 @@ def list_files(self) -> list[str]:
return self.call("list_files")

def put_file(self, operator: Operator, path: str):
"""
Upload a file to the TFTP server using an OpenDAL operator
filename = Path(path).name
client_checksum = self._compute_checksum(operator, path)

Args:
operator (Operator): OpenDAL operator for accessing the source storage
path (str): Path to the file in the source storage system
if self.call("check_file_checksum", filename, client_checksum):
self.logger.info(f"Skipping upload of identical file: {filename}")
return filename

Returns:
str: Name of the uploaded file
"""
filename = Path(path).name
with OpendalAdapter(client=self, operator=operator, path=path, mode="rb") as handle:
return self.call("put_file", filename, handle)
return self.call("put_file", filename, handle, client_checksum)

def put_local_file(self, filepath: str):
"""
Upload a file from the local filesystem to the TFTP server
Note: this doesn't use TFTP to upload.
absolute = Path(filepath).resolve()
filename = absolute.name

Args:
filepath (str): Path to the local file to upload
operator = Operator("fs", root="/")
client_checksum = self._compute_checksum(operator, str(absolute))

Returns:
str: Name of the uploaded file
if self.call("check_file_checksum", filename, client_checksum):
self.logger.info(f"Skipping upload of identical file: {filename}")
return filename

Example:
>>> client.put_local_file("/path/to/local/file.txt")
"""
absolute = Path(filepath).resolve()
with OpendalAdapter(client=self, operator=Operator("fs", root="/"), path=str(absolute), mode="rb") as handle:
return self.call("put_file", absolute.name, handle)
self.logger.info(f"checksum: {client_checksum}")
with OpendalAdapter(client=self, operator=operator, path=str(absolute), mode="rb") as handle:
return self.call("put_file", filename, handle, client_checksum)

def delete_file(self, filename: str):
"""
Expand Down Expand Up @@ -108,3 +103,10 @@ def get_port(self) -> int:
int: The port number (default is 69)
"""
return self.call("get_port")

def _compute_checksum(self, operator: Operator, path: str) -> str:
hasher = hashlib.sha256()
with operator.open(path, "rb") as f:
while chunk := f.read(CHUNK_SIZE):
hasher.update(chunk)
return hasher.hexdigest()
60 changes: 37 additions & 23 deletions packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import hashlib
import os
import socket
import threading
Expand All @@ -10,33 +11,28 @@

from jumpstarter_driver_tftp.server import TftpServer

from . import CHUNK_SIZE
from jumpstarter.driver import Driver, export


class TftpError(Exception):
"""Base exception for TFTP server errors"""

pass


class ServerNotRunning(TftpError):
"""Server is not running"""

pass


class FileNotFound(TftpError):
"""File not found"""

pass


@dataclass(kw_only=True)
class Tftp(Driver):
"""TFTP Server driver for Jumpstarter"""

root_dir: str = "/var/lib/tftpboot"
host: str = field(default=None)
host: str = field(default='')
port: int = 69
server: Optional["TftpServer"] = field(init=False, default=None)
server_thread: Optional[threading.Thread] = field(init=False, default=None)
Expand All @@ -49,7 +45,7 @@ def __post_init__(self):
super().__post_init__()

os.makedirs(self.root_dir, exist_ok=True)
if self.host is None:
if self.host == '':
self.host = self.get_default_ip()

def get_default_ip(self):
Expand All @@ -71,10 +67,7 @@ def _start_server(self):
asyncio.set_event_loop(self._loop)
self.server = TftpServer(host=self.host, port=self.port, root_dir=self.root_dir)
try:
# Signal that the loop is ready
self._loop_ready.set()

# Run the server until shutdown is requested
self._loop.run_until_complete(self._run_server())
except Exception as e:
self.logger.error(f"Error running TFTP server: {e}")
Expand All @@ -84,7 +77,6 @@ def _start_server(self):
self._loop.close()
except Exception as e:
self.logger.error(f"Error during event loop cleanup: {e}")

self._loop = None
self.logger.info("TFTP server thread completed")

Expand All @@ -109,11 +101,9 @@ def start(self):
self.logger.warning("TFTP server is already running")
return

# Clear any previous shutdown state
self._shutdown_event.clear()
self._loop_ready.clear()

# Start the server thread
self.server_thread = threading.Thread(target=self._start_server, daemon=True)
self.server_thread.start()

Expand All @@ -131,7 +121,6 @@ def stop(self):
return

self.logger.info("Initiating TFTP server shutdown")

self._shutdown_event.set()
self.server_thread.join(timeout=10)
if self.server_thread.is_alive():
Expand All @@ -145,11 +134,10 @@ def list_files(self) -> list[str]:
return os.listdir(self.root_dir)

@export
async def put_file(self, filename: str, src_stream):
"""Handle file upload using streaming"""
try:
file_path = os.path.join(self.root_dir, filename)
async def put_file(self, filename: str, src_stream, client_checksum: str):
file_path = os.path.join(self.root_dir, filename)

try:
if not Path(file_path).resolve().is_relative_to(Path(self.root_dir).resolve()):
raise TftpError("Invalid target path")

Expand All @@ -159,19 +147,38 @@ async def put_file(self, filename: str, src_stream):
await dst.send(chunk)

return filename

except Exception as e:
raise TftpError(f"Failed to upload file: {str(e)}") from e

@export
def delete_file(self, filename: str):
file_path = os.path.join(self.root_dir, filename)

if not os.path.exists(file_path):
raise FileNotFound(f"File {filename} not found")

try:
os.remove(os.path.join(self.root_dir, filename))
except FileNotFoundError as err:
raise FileNotFound(f"File {filename} not found") from err
os.remove(file_path)
return filename
except Exception as e:
raise TftpError(f"Failed to delete {filename}") from e

@export
def check_file_checksum(self, filename: str, client_checksum: str) -> bool:
file_path = os.path.join(self.root_dir, filename)
self.logger.debug(f"checking checksum for file: {filename}")
self.logger.debug(f"file path: {file_path}")

if not os.path.exists(file_path):
self.logger.debug(f"File {filename} does not exist")
return False

current_checksum = self._compute_checksum(file_path)
self.logger.debug(f"Computed checksum: {current_checksum}")
self.logger.debug(f"Client checksum: {client_checksum}")

return current_checksum == client_checksum

@export
def get_host(self) -> str:
return self.host
Expand All @@ -184,3 +191,10 @@ def close(self):
if self.server_thread is not None:
self.stop()
super().close()

def _compute_checksum(self, path: str) -> str:
hasher = hashlib.sha256()
with open(path, "rb") as f:
while chunk := f.read(CHUNK_SIZE):
hasher.update(chunk)
return hasher.hexdigest()
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import os
import tempfile
from pathlib import Path
Expand All @@ -20,18 +21,17 @@ def temp_dir():
with tempfile.TemporaryDirectory() as tmpdir:
yield tmpdir


@pytest.fixture
def server(temp_dir):
server = Tftp(root_dir=temp_dir, host="127.0.0.1")
yield server
server.close()


@pytest.mark.anyio
async def test_tftp_file_operations(server):
filename = "test.txt"
test_data = b"Hello"
client_checksum = hashlib.sha256(test_data).hexdigest()

send_stream, receive_stream = create_memory_object_stream(max_buffer_size=10)

Expand All @@ -46,8 +46,7 @@ async def send_data():

async with anyio.create_task_group() as tg:
tg.start_soon(send_data)

await server.put_file(filename, resource_handle)
await server.put_file(filename, resource_handle, client_checksum)

files = server.list_files()
assert filename in files
Expand All @@ -61,20 +60,48 @@ async def send_data():
with pytest.raises(FileNotFound):
server.delete_file("nonexistent.txt")


def test_tftp_host_config(temp_dir):
custom_host = "192.168.1.1"
server = Tftp(root_dir=temp_dir, host=custom_host)
assert server.get_host() == custom_host


def test_tftp_root_directory_creation(temp_dir):
new_dir = os.path.join(temp_dir, "new_tftp_root")
server = Tftp(root_dir=new_dir)
assert os.path.exists(new_dir)
server.close()

@pytest.mark.anyio
async def test_tftp_detect_corrupted_file(server):
filename = "corrupted.txt"
original_data = b"Original Data"
client_checksum = hashlib.sha256(original_data).hexdigest()

await _upload_file(server, filename, original_data)

assert server.check_file_checksum(filename, client_checksum)

file_path = Path(server.root_dir, filename)
file_path.write_bytes(b"corrupted Data")

assert not server.check_file_checksum(filename, client_checksum)

@pytest.fixture
def anyio_backend():
return "asyncio"

async def _upload_file(server, filename: str, data: bytes) -> str:
send_stream, receive_stream = create_memory_object_stream()
resource_uuid = uuid4()
server.resources[resource_uuid] = receive_stream
resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json")

async def send_data():
await send_stream.send(data)
await send_stream.aclose()

async with anyio.create_task_group() as tg:
tg.start_soon(send_data)
await server.put_file(filename, resource_handle, hashlib.sha256(data).hexdigest())

return hashlib.sha256(data).hexdigest()
Loading

0 comments on commit ee4720b

Please sign in to comment.