Skip to content

Commit

Permalink
tests(python): improve thp device tests
Browse files Browse the repository at this point in the history
[no changelog]
  • Loading branch information
M1nd3r committed Jan 29, 2025
1 parent d8506f8 commit d636c3f
Showing 1 changed file with 97 additions and 148 deletions.
245 changes: 97 additions & 148 deletions tests/device_tests/thp/test_thp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
import pytest
import typing_extensions as tx

from trezorlib import protobuf
from trezorlib.client import ProtocolV2
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import (
ButtonAck,
ButtonRequest,
ThpCodeEntryChallenge,
ThpCodeEntryCommitment,
ThpCodeEntryCpaceHostTag,
ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret,
ThpEndRequest,
ThpEndResponse,
ThpPairingMethod,
Expand All @@ -29,11 +32,26 @@
if t.TYPE_CHECKING:
P = tx.ParamSpec("P")

MT = t.TypeVar("MT", bound=protobuf.MessageType)

pytestmark = [pytest.mark.protocol("protocol_v2")]


protocol: ProtocolV2


def _prepare_protocol(client: Client):
global protocol
protocol = client.protocol
protocol.sync_bit_send = 0
protocol.sync_bit_receive = 0


def test_allocate_channel(client: Client) -> None:
protocol: ProtocolV2 = client.protocol
global protocol
_prepare_protocol(client)

# protocol: ProtocolV2 = client.protocol
nonce = b"\x1A\x2B\x3B\x4A\x5C\x6D\x7E\x8F"

# Use valid nonce
Expand All @@ -50,10 +68,10 @@ def test_allocate_channel(client: Client) -> None:


def test_handshake(client: Client) -> None:
protocol: ProtocolV2 = client.protocol
global protocol
_prepare_protocol(client)
# protocol: ProtocolV2 = client.protocol

protocol.sync_bit_send = 0
protocol.sync_bit_receive = 0
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)

Expand Down Expand Up @@ -88,10 +106,27 @@ def test_handshake(client: Client) -> None:
assert noise_tag is not None


def _send_message(
message: MT,
session_id: int = MANAGEMENT_SESSION_ID,
):
global protocol
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(session_id, message_type, message_data)
protocol._read_ack()


def _read_message(message_type: type[MT]) -> MT:
global protocol
_, msg_type, msg_data = protocol.read_and_decrypt()
msg = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(msg, message_type)
return msg


def test_pairing_qr_code(client: Client) -> None:
protocol: ProtocolV2 = client.protocol
protocol.sync_bit_send = 0
protocol.sync_bit_receive = 0
global protocol
_prepare_protocol(client)

# Generate ephemeral keys
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
Expand All @@ -101,96 +136,54 @@ def test_pairing_qr_code(client: Client) -> None:

protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)

# Send StartPairingReqest message
message = ThpPairingRequest()
message_type, message_data = protocol.mapping.encode(message)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)

# Read ACK
protocol._read_ack()

# Read button request
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ButtonRequest)
_send_message(ThpPairingRequest())

# Send button ACK
message = ButtonAck()
message_type, message_data = protocol.mapping.encode(message)
_read_message(ButtonRequest)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
protocol._read_ack()
_send_message(ButtonAck())

client.debug.press_yes()

# Read PairingRequestApproved
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)

assert isinstance(maaa, ThpPairingRequestApproved)

message = ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)
message_type, message_data = protocol.mapping.encode(message)
_read_message(ThpPairingRequestApproved)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
protocol._read_ack()
_send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode))

# Read ThpPairingPreparationsFinished
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpPairingPreparationsFinished)
_read_message(ThpPairingPreparationsFinished)

# QR Code shown
# Read button request
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ButtonRequest)

# Send button ACK
message = ButtonAck()
message_type, message_data = protocol.mapping.encode(message)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
protocol._read_ack()
_read_message(ButtonRequest)
_send_message(ButtonAck())

# Read code from "Trezor's display" using debuglink
state = client.debug.state(thp_channel_id=protocol.channel_id.to_bytes(2, "big"))
code = state.thp_pairing_code_qr_code

# Compute tag for response
sha_ctx = hashlib.sha256(protocol.handshake_hash)
sha_ctx.update(state.thp_pairing_code_qr_code)
tag = sha_ctx.digest()

message_type, message_data = protocol.mapping.encode(ThpQrCodeTag(tag=tag))
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)

protocol._read_ack()

# Read ThpQrCodeSecret
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpQrCodeSecret)
_send_message(ThpQrCodeTag(tag=tag))

message = ThpEndRequest()
message_type, message_data = protocol.mapping.encode(message)
secret_msg = _read_message(ThpQrCodeSecret)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
protocol._read_ack()
# Check that the `code` was derived from the revealed secret
sha_ctx = hashlib.sha256(ThpPairingMethod.QrCode.to_bytes(1, "big"))
sha_ctx.update(protocol.handshake_hash)
sha_ctx.update(secret_msg.secret)
computed_code = sha_ctx.digest()[:16]
assert code == computed_code

# Read ThpEndResponse
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpEndResponse)
_send_message(ThpEndRequest())
_read_message(ThpEndResponse)

protocol._has_valid_channel = True


@pytest.mark.skip("Cpace is not implemented yet")
# @pytest.mark.skip("Cpace is not implemented yet")
def test_pairing_code_entry(client: Client) -> None:
protocol: ProtocolV2 = client.protocol
protocol.sync_bit_send = 0
protocol.sync_bit_receive = 0
global protocol
_prepare_protocol(client)

# Generate ephemeral keys
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
Expand All @@ -200,101 +193,57 @@ def test_pairing_code_entry(client: Client) -> None:

protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)

# Send StartPairingReqest message
message = ThpPairingRequest()
message_type, message_data = protocol.mapping.encode(message)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)

# Read ACK
protocol._read_ack()
_send_message(ThpPairingRequest())

# Read button request
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ButtonRequest)
_read_message(ButtonRequest)

# Send button ACK
message = ButtonAck()
message_type, message_data = protocol.mapping.encode(message)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
protocol._read_ack()
_send_message(ButtonAck())

client.debug.press_yes()

# Read PairingRequestApproved
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)

assert isinstance(maaa, ThpPairingRequestApproved)
_read_message(ThpPairingRequestApproved)

message = ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)
message_type, message_data = protocol.mapping.encode(message)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
protocol._read_ack()
_send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry))

# Read ThpCodeEntryCommitment
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpCodeEntryCommitment)
commitment_msg = _read_message(ThpCodeEntryCommitment)
commitment = commitment_msg.commitment

challenge = b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xAA\xBB\xCC\xDD\xEE\xFF"
message = ThpCodeEntryChallenge(challenge=challenge)
message_type, message_data = protocol.mapping.encode(message)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
protocol._read_ack()

# Read ThpCodeEntryCpaceTrezor
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpCodeEntryCpaceTrezor)
_send_message(ThpCodeEntryChallenge(challenge=challenge))

_ = maaa.cpace_trezor_public_key
cpace_trezor = _read_message(ThpCodeEntryCpaceTrezor)
cpace_trezor_public_key = cpace_trezor.cpace_trezor_public_key

# Code Entry code shown
# Read button request
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ButtonRequest)

# Send button ACK
message = ButtonAck()
message_type, message_data = protocol.mapping.encode(message)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
protocol._read_ack()
_read_message(ButtonRequest)
_send_message(ButtonAck())

state = client.debug.state(thp_channel_id=protocol.channel_id.to_bytes(2, "big"))
code = state.thp_pairing_code_entry_code

sha_ctx = hashlib.sha256(protocol.handshake_hash)
sha_ctx.update(state.thp_pairing_code_entry_code)
# TODO fix missing CPACE
cpace_shared_secret = b"\x01"
sha_ctx = hashlib.sha256(cpace_shared_secret)
tag = sha_ctx.digest()

message_type, message_data = protocol.mapping.encode(ThpQrCodeTag(tag=tag))
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
cpace_host_public_key = cpace_trezor_public_key

protocol._read_ack()

# Read ThpQrCodeSecret
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpQrCodeSecret)
_send_message(
ThpCodeEntryCpaceHostTag(
cpace_host_public_key=cpace_host_public_key,
tag=tag,
)
)

message = ThpEndRequest()
message_type, message_data = protocol.mapping.encode(message)
secret_msg = _read_message(ThpCodeEntrySecret)

protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
protocol._read_ack()
# Check `commitment` and `code`
sha_ctx = hashlib.sha256(secret_msg.secret)
computed_commitment = sha_ctx.digest()
assert commitment == computed_commitment
assert code == b"" # TODO implement

# Read ThpEndResponse
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpEndResponse)
_send_message(ThpEndRequest())
_read_message(ThpEndResponse)

protocol._has_valid_channel = True

0 comments on commit d636c3f

Please sign in to comment.