From d636c3fa0d22168f412193a34ec5a466b6742c9e Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 29 Jan 2025 21:41:36 +0100 Subject: [PATCH] tests(python): improve thp device tests [no changelog] --- tests/device_tests/thp/test_thp.py | 245 ++++++++++++----------------- 1 file changed, 97 insertions(+), 148 deletions(-) diff --git a/tests/device_tests/thp/test_thp.py b/tests/device_tests/thp/test_thp.py index 4054202a450..3a41c0c55e5 100644 --- a/tests/device_tests/thp/test_thp.py +++ b/tests/device_tests/thp/test_thp.py @@ -5,6 +5,7 @@ import pytest import typing_extensions as tx +from trezorlib import protobuf from trezorlib.client import ProtocolV2 from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import ( @@ -12,7 +13,9 @@ ButtonRequest, ThpCodeEntryChallenge, ThpCodeEntryCommitment, + ThpCodeEntryCpaceHostTag, ThpCodeEntryCpaceTrezor, + ThpCodeEntrySecret, ThpEndRequest, ThpEndResponse, ThpPairingMethod, @@ -29,11 +32,26 @@ if t.TYPE_CHECKING: P = tx.ParamSpec("P") +MT = t.TypeVar("MT", bound=protobuf.MessageType) + pytestmark = [pytest.mark.protocol("protocol_v2")] +protocol: ProtocolV2 + + +def _prepare_protocol(client: Client): + global protocol + protocol = client.protocol + protocol.sync_bit_send = 0 + protocol.sync_bit_receive = 0 + + def test_allocate_channel(client: Client) -> None: - protocol: ProtocolV2 = client.protocol + global protocol + _prepare_protocol(client) + + # protocol: ProtocolV2 = client.protocol nonce = b"\x1A\x2B\x3B\x4A\x5C\x6D\x7E\x8F" # Use valid nonce @@ -50,10 +68,10 @@ def test_allocate_channel(client: Client) -> None: def test_handshake(client: Client) -> None: - protocol: ProtocolV2 = client.protocol + global protocol + _prepare_protocol(client) + # protocol: ProtocolV2 = client.protocol - protocol.sync_bit_send = 0 - protocol.sync_bit_receive = 0 host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) @@ -88,10 +106,27 @@ def test_handshake(client: Client) -> None: assert noise_tag is not None +def _send_message( + message: MT, + session_id: int = MANAGEMENT_SESSION_ID, +): + global protocol + message_type, message_data = protocol.mapping.encode(message) + protocol._encrypt_and_write(session_id, message_type, message_data) + protocol._read_ack() + + +def _read_message(message_type: type[MT]) -> MT: + global protocol + _, msg_type, msg_data = protocol.read_and_decrypt() + msg = protocol.mapping.decode(msg_type, msg_data) + assert isinstance(msg, message_type) + return msg + + def test_pairing_qr_code(client: Client) -> None: - protocol: ProtocolV2 = client.protocol - protocol.sync_bit_send = 0 - protocol.sync_bit_receive = 0 + global protocol + _prepare_protocol(client) # Generate ephemeral keys host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) @@ -101,96 +136,54 @@ def test_pairing_qr_code(client: Client) -> None: protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey) - # Send StartPairingReqest message - message = ThpPairingRequest() - message_type, message_data = protocol.mapping.encode(message) - - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - - # Read ACK - protocol._read_ack() - - # Read button request - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ButtonRequest) + _send_message(ThpPairingRequest()) - # Send button ACK - message = ButtonAck() - message_type, message_data = protocol.mapping.encode(message) + _read_message(ButtonRequest) - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - protocol._read_ack() + _send_message(ButtonAck()) client.debug.press_yes() - # Read PairingRequestApproved - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - - assert isinstance(maaa, ThpPairingRequestApproved) - - message = ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode) - message_type, message_data = protocol.mapping.encode(message) + _read_message(ThpPairingRequestApproved) - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - # Read ACK - protocol._read_ack() + _send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)) - # Read ThpPairingPreparationsFinished - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ThpPairingPreparationsFinished) + _read_message(ThpPairingPreparationsFinished) # QR Code shown - # Read button request - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ButtonRequest) - - # Send button ACK - message = ButtonAck() - message_type, message_data = protocol.mapping.encode(message) - - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - protocol._read_ack() + _read_message(ButtonRequest) + _send_message(ButtonAck()) + # Read code from "Trezor's display" using debuglink state = client.debug.state(thp_channel_id=protocol.channel_id.to_bytes(2, "big")) + code = state.thp_pairing_code_qr_code + # Compute tag for response sha_ctx = hashlib.sha256(protocol.handshake_hash) sha_ctx.update(state.thp_pairing_code_qr_code) tag = sha_ctx.digest() - message_type, message_data = protocol.mapping.encode(ThpQrCodeTag(tag=tag)) - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - - protocol._read_ack() - - # Read ThpQrCodeSecret - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ThpQrCodeSecret) + _send_message(ThpQrCodeTag(tag=tag)) - message = ThpEndRequest() - message_type, message_data = protocol.mapping.encode(message) + secret_msg = _read_message(ThpQrCodeSecret) - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - # Read ACK - protocol._read_ack() + # Check that the `code` was derived from the revealed secret + sha_ctx = hashlib.sha256(ThpPairingMethod.QrCode.to_bytes(1, "big")) + sha_ctx.update(protocol.handshake_hash) + sha_ctx.update(secret_msg.secret) + computed_code = sha_ctx.digest()[:16] + assert code == computed_code - # Read ThpEndResponse - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ThpEndResponse) + _send_message(ThpEndRequest()) + _read_message(ThpEndResponse) protocol._has_valid_channel = True -@pytest.mark.skip("Cpace is not implemented yet") +# @pytest.mark.skip("Cpace is not implemented yet") def test_pairing_code_entry(client: Client) -> None: - protocol: ProtocolV2 = client.protocol - protocol.sync_bit_send = 0 - protocol.sync_bit_receive = 0 + global protocol + _prepare_protocol(client) # Generate ephemeral keys host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) @@ -200,101 +193,57 @@ def test_pairing_code_entry(client: Client) -> None: protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey) - # Send StartPairingReqest message - message = ThpPairingRequest() - message_type, message_data = protocol.mapping.encode(message) - - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - - # Read ACK - protocol._read_ack() + _send_message(ThpPairingRequest()) - # Read button request - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ButtonRequest) + _read_message(ButtonRequest) - # Send button ACK - message = ButtonAck() - message_type, message_data = protocol.mapping.encode(message) - - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - protocol._read_ack() + _send_message(ButtonAck()) client.debug.press_yes() - # Read PairingRequestApproved - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - - assert isinstance(maaa, ThpPairingRequestApproved) + _read_message(ThpPairingRequestApproved) - message = ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry) - message_type, message_data = protocol.mapping.encode(message) - - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - # Read ACK - protocol._read_ack() + _send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)) - # Read ThpCodeEntryCommitment - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ThpCodeEntryCommitment) + commitment_msg = _read_message(ThpCodeEntryCommitment) + commitment = commitment_msg.commitment challenge = b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xAA\xBB\xCC\xDD\xEE\xFF" - message = ThpCodeEntryChallenge(challenge=challenge) - message_type, message_data = protocol.mapping.encode(message) - - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - # Read ACK - protocol._read_ack() - - # Read ThpCodeEntryCpaceTrezor - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ThpCodeEntryCpaceTrezor) + _send_message(ThpCodeEntryChallenge(challenge=challenge)) - _ = maaa.cpace_trezor_public_key + cpace_trezor = _read_message(ThpCodeEntryCpaceTrezor) + cpace_trezor_public_key = cpace_trezor.cpace_trezor_public_key # Code Entry code shown - # Read button request - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ButtonRequest) - - # Send button ACK - message = ButtonAck() - message_type, message_data = protocol.mapping.encode(message) - - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - protocol._read_ack() + _read_message(ButtonRequest) + _send_message(ButtonAck()) state = client.debug.state(thp_channel_id=protocol.channel_id.to_bytes(2, "big")) + code = state.thp_pairing_code_entry_code - sha_ctx = hashlib.sha256(protocol.handshake_hash) - sha_ctx.update(state.thp_pairing_code_entry_code) + # TODO fix missing CPACE + cpace_shared_secret = b"\x01" + sha_ctx = hashlib.sha256(cpace_shared_secret) tag = sha_ctx.digest() - message_type, message_data = protocol.mapping.encode(ThpQrCodeTag(tag=tag)) - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) + cpace_host_public_key = cpace_trezor_public_key - protocol._read_ack() - - # Read ThpQrCodeSecret - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ThpQrCodeSecret) + _send_message( + ThpCodeEntryCpaceHostTag( + cpace_host_public_key=cpace_host_public_key, + tag=tag, + ) + ) - message = ThpEndRequest() - message_type, message_data = protocol.mapping.encode(message) + secret_msg = _read_message(ThpCodeEntrySecret) - protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - # Read ACK - protocol._read_ack() + # Check `commitment` and `code` + sha_ctx = hashlib.sha256(secret_msg.secret) + computed_commitment = sha_ctx.digest() + assert commitment == computed_commitment + assert code == b"" # TODO implement - # Read ThpEndResponse - _, msg_type, msg_data = protocol.read_and_decrypt() - maaa = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, ThpEndResponse) + _send_message(ThpEndRequest()) + _read_message(ThpEndResponse) protocol._has_valid_channel = True