Skip to content

Commit

Permalink
Added enhanced authentication support to MQTTClient.
Browse files Browse the repository at this point in the history
Signed-off-by: Diego Dassie <[email protected]>
  • Loading branch information
ddassie-texa committed May 30, 2023
1 parent 39a8672 commit cbe0350
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 136 deletions.
3 changes: 2 additions & 1 deletion src/MQTTAsync.c
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,8 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options)
property.value.data.data = authData.authDataOut.data;
property.value.data.len = authData.authDataOut.len;
rc = MQTTProperties_add(m->connectProps, &property);
free(authData.authDataOut.data);
if(authData.authDataOut.data)
free(authData.authDataOut.data);
if (rc)
goto exit;
}
Expand Down
3 changes: 2 additions & 1 deletion src/MQTTAsync.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,8 @@ typedef struct
* This is a callback function which will allow the client application to update the
* connection data.
* @param data The connection data which can be modified by the application.
* @return Return a zero or positive value to indicate sucess, a negative value on failure.
* @return a negative value to indicate not-authorized, a value of 0 to indicate success,
* a positive value to indicate continue.
*/
typedef int MQTTAsync_authHandle(void* context, MQTTAsync_authHandleData* data);

Expand Down
143 changes: 50 additions & 93 deletions src/MQTTAsyncUtils.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ static int cmdMessageIDCompare(void* a, void* b);
static void MQTTAsync_retry(void);
static MQTTPacket* MQTTAsync_cycle(SOCKET* sock, unsigned long timeout, int* rc);
static int MQTTAsync_connecting(MQTTAsyncs* m);
static enum MQTTReasonCodes MQTTAsync_processAuth(MQTTAsync_authHandle *func, void *context,
MQTTAsync_authHandleData *data);
static enum MQTTReasonCodes MQTTAsync_processAuth(MQTTAsyncs *m, int rc,
MQTTProperties *props, MQTTAsync_authHandleData *out);
static int MQTTAsync_verifyAuthMethod(const char* authMethod,
const char* data, int dataLen);

Expand Down Expand Up @@ -2140,39 +2140,10 @@ thread_return_type WINAPI MQTTAsync_receiveThread(void* n)
if (rc == MQTTASYNC_SUCCESS && m->c->authMethod)
{
MQTTAsync_authHandleData authHandleData = MQTTAsync_authHandleData_initializer;
MQTTProperty *authMethodProp = NULL;
char *authMethod = NULL;
int authMethodLen = 0;
MQTTProperty *authData = NULL;

authMethodProp = MQTTProperties_getProperty(&connack->properties,
MQTTPROPERTY_CODE_AUTHENTICATION_METHOD);
if (authMethodProp)
{
authMethod = authMethodProp->value.data.data;
authMethodLen = authMethodProp->value.data.len;
}

if ((connack->rc == 0 && authMethodProp == NULL) ||
MQTTAsync_verifyAuthMethod(m->c->authMethod, authMethod,
authMethodLen) == 0)
{
authData = MQTTProperties_getProperty(&connack->properties,
MQTTPROPERTY_CODE_AUTHENTICATION_DATA);

authHandleData.reasonCode = connack->rc;
if (authData && authMethod)
{
authHandleData.authDataIn.data = authData->value.data.data;
authHandleData.authDataIn.len = authData->value.data.len;
}

rc = MQTTAsync_processAuth(m->auth_handle,
m->auth_handle_context,
&authHandleData);
}
else
rc = MQTTREASONCODE_BAD_AUTHENTICATION_METHOD;
rc = MQTTAsync_processAuth(m, connack->rc, &connack->properties,
&authHandleData);
if (authHandleData.authDataOut.data)
free(authHandleData.authDataOut.data);
}

if (rc == MQTTASYNC_SUCCESS)
Expand Down Expand Up @@ -2412,50 +2383,16 @@ thread_return_type WINAPI MQTTAsync_receiveThread(void* n)
{
Auth *auth = (Auth *)pack;
enum MQTTReasonCodes authrc = MQTTREASONCODE_SUCCESS;
MQTTProperty *authMethodProp = NULL;
MQTTProperty *authData = NULL;
char* authMethod = NULL;
int authMethodLen = 0;
MQTTAsync_authHandleData authHandleData = MQTTAsync_authHandleData_initializer;

if (m->c->authMethod)
{
authMethodProp = MQTTProperties_getProperty(&auth->properties,
MQTTPROPERTY_CODE_AUTHENTICATION_METHOD);
if (authMethodProp)
{
authMethod = authMethodProp->value.data.data;
authMethodLen = authMethodProp->value.data.len;
}

if ((auth->rc == 0 && authMethodProp == NULL) ||
MQTTAsync_verifyAuthMethod(m->c->authMethod, authMethod,
authMethodLen) == 0)
{
authData = MQTTProperties_getProperty(&auth->properties,
MQTTPROPERTY_CODE_AUTHENTICATION_DATA);

authHandleData.reasonCode = auth->rc;
if (authData && authMethod)
{
authHandleData.authDataIn.data = authData->value.data.data;
authHandleData.authDataIn.len = authData->value.data.len;
}

authrc = MQTTAsync_processAuth(m->auth_handle,
m->auth_handle_context,
&authHandleData);
}
else
authrc = MQTTREASONCODE_BAD_AUTHENTICATION_METHOD;
}
else
authrc = MQTTREASONCODE_PROTOCOL_ERROR;
authrc = MQTTAsync_processAuth(m, auth->rc, &auth->properties,
&authHandleData);

rc = MQTTProtocol_handleAuth(pack, m->c->net.socket, authrc,
authHandleData.authDataOut.data,
authHandleData.authDataOut.len);
free(authHandleData.authDataOut.data);
if (authHandleData.authDataOut.data)
free(authHandleData.authDataOut.data);
if (authrc != MQTTREASONCODE_SUCCESS &&
authrc != MQTTREASONCODE_CONTINUE_AUTHENTICATION)
nextOrClose(m, authrc, "Authentication failed");
Expand Down Expand Up @@ -3320,34 +3257,54 @@ int MQTTAsync_getNoBufferedMessages(MQTTAsyncs* m)
}


enum MQTTReasonCodes MQTTAsync_processAuth(MQTTAsync_authHandle *func, void *context,
MQTTAsync_authHandleData *data)
enum MQTTReasonCodes MQTTAsync_processAuth(MQTTAsyncs *m, int rc, MQTTProperties *props,
MQTTAsync_authHandleData *out)
{
int rc;
MQTTProperty *authMethodProp = NULL;
char *authMethodIn = NULL;
int authMethodInLen = 0;
MQTTProperty *authData = NULL;

if (func == NULL)
return MQTTREASONCODE_NOT_AUTHORIZED;

rc = (*(func))(context, data);
if (rc < 0)
return MQTTREASONCODE_NOT_AUTHORIZED;
if (!m->c->authMethod)
{
return MQTTREASONCODE_PROTOCOL_ERROR;
}

if (rc > 0)
return MQTTREASONCODE_CONTINUE_AUTHENTICATION;
authMethodProp = MQTTProperties_getProperty(props,
MQTTPROPERTY_CODE_AUTHENTICATION_METHOD);
if (authMethodProp)
{
authMethodIn = authMethodProp->value.data.data;
authMethodInLen = authMethodProp->value.data.len;
}

return MQTTREASONCODE_SUCCESS;
}
if ((rc == 0 && authMethodProp == NULL) ||
(m->c->authMethod && authMethodIn && authMethodInLen > 0 &&
strlen(m->c->authMethod) == authMethodInLen &&
memcmp(m->c->authMethod, authMethodIn, authMethodInLen) == 0))
{
if (m->auth_handle == NULL)
return MQTTREASONCODE_NOT_AUTHORIZED;

authData = MQTTProperties_getProperty(props,
MQTTPROPERTY_CODE_AUTHENTICATION_DATA);

int MQTTAsync_verifyAuthMethod(const char *authMethod, const char *data, int dataLen)
{
if (authMethod && data && dataLen > 0)
{
if (strlen(authMethod) == dataLen && memcmp(authMethod, data, dataLen) == 0)
out->reasonCode = rc;
if (authData && authMethodIn)
{
return 0;
out->authDataIn.data = authData->value.data.data;
out->authDataIn.len = authData->value.data.len;
}

rc = (*(m->auth_handle))(m->auth_handle_context, out);
if (rc < 0)
return MQTTREASONCODE_NOT_AUTHORIZED;

if (rc > 0)
return MQTTREASONCODE_CONTINUE_AUTHENTICATION;

return MQTTREASONCODE_SUCCESS;
}

return -1;
return MQTTREASONCODE_BAD_AUTHENTICATION_METHOD;
}
111 changes: 80 additions & 31 deletions src/MQTTClient.c
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r
static void MQTTProtocol_checkPendingWrites(void);
static void MQTTClient_writeComplete(SOCKET socket, int rc);
static void MQTTClient_writeContinue(SOCKET socket);
static enum MQTTReasonCodes MQTTClient_processAuth(MQTTClients *m, int rc,
MQTTProperties *props, MQTTClient_handleAuthData *out);


int MQTTClient_createWithOptions(MQTTClient* handle, const char* serverURI, const char* clientId,
Expand Down Expand Up @@ -801,25 +803,6 @@ int MQTTClient_setHandleAuth(MQTTClient handle, void* context, MQTTClient_handle
}


/**
* Wrapper function to call authHandle on a separate thread. A separate thread is needed to allow the
* disconnected function to make API calls (e.g. MQTTClient_auth)
* @param context a pointer to the relevant client
* @return thread_return_type standard thread return value - not used here
*/
static thread_return_type WINAPI call_auth_handle(void* context)
{
struct props_rc_parms* pr = (struct props_rc_parms*)context;

//(*(pr->m->auth_handle))(pr->m->auth_handle_context, pr->properties, pr->reasonCode);
abort(); //TODO: Implement for MQTTClient
MQTTProperties_free(pr->properties);
free(pr->properties);
free(pr);
return 0;
}


/* This is the thread function that handles the calling of callback functions if set */
static thread_return_type WINAPI MQTTClient_run(void* n)
{
Expand Down Expand Up @@ -953,17 +936,23 @@ static thread_return_type WINAPI MQTTClient_run(void* n)
}
free(disc);
}
else if (pack->header.bits.type == AUTH && m->auth_handle)
else if (pack->header.bits.type == AUTH)
{
struct props_rc_parms dp;
Ack* disc = (Ack*)pack;

dp.m = m;
dp.properties = &disc->properties;
dp.reasonCode = disc->rc;
free(pack);
Log(TRACE_MIN, -1, "Calling auth_handle for client %s", m->c->clientID);
Thread_start(call_auth_handle, &dp);
Auth *auth = (Auth *)pack;
enum MQTTReasonCodes authrc = MQTTREASONCODE_SUCCESS;
MQTTClient_handleAuthData authHandleData = MQTTClient_handleAuthData_initializer;

authrc = MQTTClient_processAuth(m, auth->rc, &auth->properties,
&authHandleData);

rc = MQTTProtocol_handleAuth(pack, m->c->net.socket, authrc,
authHandleData.authDataOut.data,
authHandleData.authDataOut.len);
if (authHandleData.authDataOut.data)
free(authHandleData.authDataOut.data);
if (authrc != MQTTREASONCODE_SUCCESS &&
authrc != MQTTREASONCODE_CONTINUE_AUTHENTICATION)
MQTTClient_disconnect_internal(m, 0);
}
}
}
Expand Down Expand Up @@ -1421,7 +1410,18 @@ static MQTTResponse MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_c
{
Connack* connack = (Connack*)pack;
Log(TRACE_PROTOCOL, 1, NULL, m->c->net.socket, m->c->clientID, connack->rc);
if ((rc = connack->rc) == MQTTCLIENT_SUCCESS)
rc = connack->rc;

if (rc == MQTTCLIENT_SUCCESS && m->c->authMethod)
{
MQTTClient_handleAuthData authHandleData = MQTTClient_handleAuthData_initializer;
rc = MQTTClient_processAuth(m, connack->rc, &connack->properties,
&authHandleData);
if (authHandleData.authDataOut.data)
free(authHandleData.authDataOut.data);
}

if (rc == MQTTCLIENT_SUCCESS)
{
m->c->connected = 1;
m->c->good = 1;
Expand Down Expand Up @@ -1724,7 +1724,8 @@ static MQTTResponse MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectO
property.value.data.data = authData.authDataOut.data;
property.value.data.len = authData.authDataOut.len;
rc.reasonCode = MQTTProperties_add(connectProperties, &property);
free(authData.authDataOut.data);
if (authData.authDataOut.data)
free(authData.authDataOut.data);
if (rc.reasonCode)
goto exit;
}
Expand Down Expand Up @@ -3258,6 +3259,54 @@ int MQTTClient_setSelectInterface(MQTTClient handle, void* context, MQTTClient_s
return rc;
}

enum MQTTReasonCodes MQTTClient_processAuth(MQTTClients *m, int rc, MQTTProperties *props,
MQTTClient_handleAuthData *out)
{
MQTTProperty *authMethodProp = NULL;
char *authMethodIn = NULL;
int authMethodInLen = 0;
MQTTProperty *authData = NULL;

if (!m->c->authMethod)
{
return MQTTREASONCODE_PROTOCOL_ERROR;
}

authMethodProp = MQTTProperties_getProperty(props,
MQTTPROPERTY_CODE_AUTHENTICATION_METHOD);
if (authMethodProp)
{
authMethodIn = authMethodProp->value.data.data;
authMethodInLen = authMethodProp->value.data.len;
}

if ((rc == 0 && authMethodProp == NULL) ||
(m->c->authMethod && authMethodIn && authMethodInLen > 0 &&
strlen(m->c->authMethod) == authMethodInLen &&
memcmp(m->c->authMethod, authMethodIn, authMethodInLen) == 0))
{
if (m->auth_handle == NULL)
return MQTTREASONCODE_NOT_AUTHORIZED;

authData = MQTTProperties_getProperty(props,
MQTTPROPERTY_CODE_AUTHENTICATION_DATA);

out->reasonCode = rc;
if (authData && authMethodIn)
{
out->authDataIn.data = authData->value.data.data;
out->authDataIn.len = authData->value.data.len;
}

rc = (*(m->auth_handle))(m->auth_handle_context, out);
if (rc < 0)
return MQTTREASONCODE_NOT_AUTHORIZED;

if (rc > 0)
return MQTTREASONCODE_CONTINUE_AUTHENTICATION;

return MQTTREASONCODE_SUCCESS;
}

return MQTTREASONCODE_BAD_AUTHENTICATION_METHOD;
}
15 changes: 5 additions & 10 deletions src/MQTTClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -523,10 +523,14 @@ typedef struct
/**
* This is a callback function, which will be called when a MQTTv5 enhanced
* authentication packet is either received from the server or needs to be
* populated by the client. This applies to MQTT V5 and above only.
* populated by the client, it applies to MQTT V5 and above only.
* The callback is not executed on a dedicated thread, do not call other MQTTClient
* functions inside of it.
* @param context A pointer to the <i>context</i> value originally passed to
* ::MQTTClient_setHandleAuth(), which contains any application-specific context.
* @param data The MQTTClient_handleAuthData.
* @return a negative value to indicate not-authorized, a value of 0 to indicate success,
* a positive value to indicate continue.
*/
typedef int MQTTClient_handleAuth(void* context, MQTTClient_handleAuthData* data);

Expand Down Expand Up @@ -1453,15 +1457,6 @@ LIBMQTT_API int MQTTClient_receive(MQTTClient handle, char** topicName, int* top
*/
LIBMQTT_API void MQTTClient_freeMessage(MQTTClient_message** msg);

/**
* This function is used to allocate memory to be used or freed by the MQTT C client library,
* especially the data in the ::MQTTPersistence_afterRead and ::MQTTPersistence_beforeWrite
* callbacks. This is needed on Windows when the client library and application
* program have been compiled with different versions of the C compiler.
* @param size The size of the memory to be allocated.
*/
LIBMQTT_API void* MQTTClient_malloc(size_t size);

/**
* This function frees memory allocated by the MQTT C client library, especially the
* topic name. This is needed on Windows when the client libary and application
Expand Down

0 comments on commit cbe0350

Please sign in to comment.