diff --git a/library/spdm_requester_lib/libspdm_req_get_capabilities.c b/library/spdm_requester_lib/libspdm_req_get_capabilities.c index 95e05cd14bc..abff87e3ca1 100644 --- a/library/spdm_requester_lib/libspdm_req_get_capabilities.c +++ b/library/spdm_requester_lib/libspdm_req_get_capabilities.c @@ -79,7 +79,11 @@ static bool validate_responder_capability(uint32_t capabilities_flag, uint8_t ve } /* Checks that originate from key exchange capabilities. */ - if ((key_ex_cap == 0) && (psk_cap == 0)) { + if ((key_ex_cap == 1) || (psk_cap != 0)) { + if ((mac_cap == 0) && (encrypt_cap == 0)) { + return false; + } + } else { if ((mac_cap == 1) || (encrypt_cap == 1) || (handshake_in_the_clear_cap == 1) || (hbeat_cap == 1) || (key_upd_cap == 1)) { return false; diff --git a/library/spdm_responder_lib/libspdm_rsp_capabilities.c b/library/spdm_responder_lib/libspdm_rsp_capabilities.c index 3d7986d18a2..2edbf0ea3fa 100644 --- a/library/spdm_responder_lib/libspdm_rsp_capabilities.c +++ b/library/spdm_responder_lib/libspdm_rsp_capabilities.c @@ -71,7 +71,11 @@ static bool libspdm_check_request_flag_compatibility(uint32_t capabilities_flag, } /* Checks that originate from key exchange capabilities. */ - if ((key_ex_cap == 0) && (psk_cap == 0)) { + if ((key_ex_cap == 1) || (psk_cap != 0)) { + if ((mac_cap == 0) && (encrypt_cap == 0)) { + return false; + } + } else { if ((mac_cap == 1) || (encrypt_cap == 1) || (handshake_in_the_clear_cap == 1) || (hbeat_cap == 1) || (key_upd_cap == 1)) { return false; diff --git a/unit_test/test_spdm_requester/error_test/get_capabilities_err.c b/unit_test/test_spdm_requester/error_test/get_capabilities_err.c index 46a94ecee56..842872aca02 100644 --- a/unit_test/test_spdm_requester/error_test/get_capabilities_err.c +++ b/unit_test/test_spdm_requester/error_test/get_capabilities_err.c @@ -592,7 +592,33 @@ static libspdm_return_t libspdm_requester_get_capabilities_test_receive_message( } return LIBSPDM_STATUS_SUCCESS; - case 0x14: + case 0x14: { + spdm_capabilities_response_t *spdm_response; + size_t spdm_response_size; + size_t transport_header_size; + + spdm_response_size = sizeof(spdm_capabilities_response_t); + transport_header_size = LIBSPDM_TEST_TRANSPORT_HEADER_SIZE; + spdm_response = (void *)((uint8_t *)*response + transport_header_size); + + libspdm_zero_mem(spdm_response, spdm_response_size); + spdm_response->header.spdm_version = SPDM_MESSAGE_VERSION_11; + spdm_response->header.request_response_code = SPDM_CAPABILITIES; + spdm_response->header.param1 = 0; + spdm_response->header.param2 = 0; + spdm_response->ct_exponent = 0; + spdm_response->flags = + LIBSPDM_DEFAULT_CAPABILITY_RESPONSE_FLAG_VERSION_11 & + (0xFFFFFFFF ^ + (SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_ENCRYPT_CAP | + SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_MAC_CAP | + SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_HANDSHAKE_IN_THE_CLEAR_CAP)); + + libspdm_transport_test_encode_message(spdm_context, NULL, false, + false, spdm_response_size, + spdm_response, + response_size, response); + } return LIBSPDM_STATUS_SUCCESS; case 0x15: @@ -1394,6 +1420,22 @@ static void libspdm_test_requester_get_capabilities_err_case19(void **state) static void libspdm_test_requester_get_capabilities_err_case20(void **state) { + libspdm_return_t status; + libspdm_test_context_t *spdm_test_context; + libspdm_context_t *spdm_context; + + spdm_test_context = *state; + spdm_context = spdm_test_context->spdm_context; + spdm_test_context->case_id = 0x14; + spdm_context->connection_info.version = SPDM_MESSAGE_VERSION_11 << + SPDM_VERSION_NUMBER_SHIFT_BIT; + spdm_context->connection_info.connection_state = LIBSPDM_CONNECTION_STATE_AFTER_VERSION; + libspdm_reset_message_a(spdm_context); + + spdm_context->local_context.capability.ct_exponent = 0; + spdm_context->local_context.capability.flags = LIBSPDM_DEFAULT_CAPABILITY_FLAG_VERSION_11; + status = libspdm_get_capabilities(spdm_context); + assert_int_equal(status, LIBSPDM_STATUS_INVALID_MSG_FIELD); } static void libspdm_test_requester_get_capabilities_err_case21(void **state) diff --git a/unit_test/test_spdm_responder/capabilities.c b/unit_test/test_spdm_responder/capabilities.c index d0129d44ac6..0948e9044cd 100644 --- a/unit_test/test_spdm_responder/capabilities.c +++ b/unit_test/test_spdm_responder/capabilities.c @@ -837,10 +837,64 @@ void libspdm_test_responder_capabilities_case15(void **state) void libspdm_test_responder_capabilities_case16(void **state) { + libspdm_return_t status; + libspdm_test_context_t *spdm_test_context; + libspdm_context_t *spdm_context; + size_t response_size; + uint8_t response[LIBSPDM_MAX_SPDM_MSG_SIZE]; + spdm_capabilities_response_t *spdm_response; + + spdm_test_context = *state; + spdm_context = spdm_test_context->spdm_context; + spdm_test_context->case_id = 0x10; + spdm_context->connection_info.connection_state = + LIBSPDM_CONNECTION_STATE_AFTER_VERSION; + + response_size = sizeof(response); + status = libspdm_get_response_capabilities( + spdm_context, m_libspdm_get_capabilities_request12_size, + &m_libspdm_get_capabilities_request12, &response_size, response); + assert_int_equal(status, LIBSPDM_STATUS_SUCCESS); + assert_int_equal(response_size, sizeof(spdm_error_response_t)); + spdm_response = (void *)response; + assert_int_equal(m_libspdm_get_capabilities_request12.header.spdm_version, + spdm_response->header.spdm_version); + assert_int_equal(spdm_response->header.request_response_code, + SPDM_ERROR); + assert_int_equal(spdm_response->header.param1, + SPDM_ERROR_CODE_INVALID_REQUEST); + assert_int_equal(spdm_response->header.param2, 0); } void libspdm_test_responder_capabilities_case17(void **state) { + libspdm_return_t status; + libspdm_test_context_t *spdm_test_context; + libspdm_context_t *spdm_context; + size_t response_size; + uint8_t response[LIBSPDM_MAX_SPDM_MSG_SIZE]; + spdm_capabilities_response_t *spdm_response; + + spdm_test_context = *state; + spdm_context = spdm_test_context->spdm_context; + spdm_test_context->case_id = 0x11; + spdm_context->connection_info.connection_state = + LIBSPDM_CONNECTION_STATE_AFTER_VERSION; + + response_size = sizeof(response); + status = libspdm_get_response_capabilities( + spdm_context, m_libspdm_get_capabilities_request13_size, + &m_libspdm_get_capabilities_request13, &response_size, response); + assert_int_equal(status, LIBSPDM_STATUS_SUCCESS); + assert_int_equal(response_size, sizeof(spdm_error_response_t)); + spdm_response = (void *)response; + assert_int_equal(m_libspdm_get_capabilities_request13.header.spdm_version, + spdm_response->header.spdm_version); + assert_int_equal(spdm_response->header.request_response_code, + SPDM_ERROR); + assert_int_equal(spdm_response->header.param1, + SPDM_ERROR_CODE_INVALID_REQUEST); + assert_int_equal(spdm_response->header.param2, 0); } void libspdm_test_responder_capabilities_case18(void **state) @@ -1173,9 +1227,9 @@ int libspdm_responder_capabilities_test_main(void) cmocka_unit_test(libspdm_test_responder_capabilities_case14), /* mac_cap set and key_ex_cap and psk_cap cleared (mac_cap demands key_ex_cap or psk_cap to be set)*/ cmocka_unit_test(libspdm_test_responder_capabilities_case15), - /* Open test case */ + /* key_ex_cap set and encrypt_cap and mac_cap cleared (key_ex_cap demands encrypt_cap or mac_cap to be set)*/ cmocka_unit_test(libspdm_test_responder_capabilities_case16), - /* Open test case */ + /* psk_cap set and encrypt_cap and mac_cap cleared (psk_cap demands encrypt_cap or mac_cap to be set)*/ cmocka_unit_test(libspdm_test_responder_capabilities_case17), /* encap_cap cleared and MUT_AUTH set (MUT_AUTH demands encap_cap to be set)*/ cmocka_unit_test(libspdm_test_responder_capabilities_case18),