diff --git a/projects/premake5.lua b/projects/premake5.lua index 426289f..d97a8cc 100644 --- a/projects/premake5.lua +++ b/projects/premake5.lua @@ -4,7 +4,14 @@ newoption({ value = "path to garrysmod_common directory" }) +newoption({ + trigger = "whitelist", + description = "Enables wrapping of getaddrinfo and a whitelist file to filter valid addresses and ports", + value = "0 or 1" +}) + local gmcommon = _OPTIONS.gmcommon or os.getenv("GARRYSMOD_COMMON") +local whitelist = _OPTIONS.whitelist == "1" if gmcommon == nil then error("you didn't provide a path to your garrysmod_common (https://github.com/danielga/garrysmod_common) directory") end @@ -16,6 +23,7 @@ local LUASOCKET_FOLDER = "../luasocket/src" CreateWorkspace({name = "socket.core"}) CreateProject({serverside = true, manual_files = true}) files("../source/socket.cpp") + if whitelist then files("../source/whitelist.cpp") defines({"USE_WHITELIST"}) includedirs(LUASOCKET_FOLDER) end IncludeLuaShared() links("socket") @@ -24,6 +32,7 @@ CreateWorkspace({name = "socket.core"}) CreateProject({serverside = false, manual_files = true}) files("../source/socket.cpp") + if whitelist then files("../source/whitelist.cpp") defines({"USE_WHITELIST"}) includedirs(LUASOCKET_FOLDER) end IncludeLuaShared() links("socket") @@ -56,6 +65,7 @@ CreateWorkspace({name = "socket.core"}) "LUASOCKET_API=__declspec(dllexport)", "MIME_API=__declspec(dllexport)" }) + if whitelist then defines({"getaddrinfo=__wrap_getaddrinfo"}) end files(LUASOCKET_FOLDER .. "/wsocket.c") links("ws2_32") @@ -65,6 +75,7 @@ CreateWorkspace({name = "socket.core"}) "UNIX_API=''", "MIME_API=''" }) + if whitelist then defines({"getaddrinfo=__wrap_getaddrinfo"}) end files(LUASOCKET_FOLDER .. "/usocket.c") CreateWorkspace({name = "mime.core"}) diff --git a/source/socket.cpp b/source/socket.cpp index 4dbf16e..c0fdd12 100644 --- a/source/socket.cpp +++ b/source/socket.cpp @@ -2,8 +2,33 @@ extern "C" int luaopen_socket_core( lua_State *state ); +int parseWhitelist(); +void clearWhitelist(); +enum : int +{ + PARSE_SUCCESS = 0, + PARSE_CANT_READ = 1, + PARSE_NO_ENTRIES = 2 +}; + GMOD_MODULE_OPEN( ) { + #ifdef USE_WHITELIST + switch (parseWhitelist()) + { + case PARSE_SUCCESS: + break; + case PARSE_CANT_READ: + LUA->ThrowError("Failed to read whitelist file!"); + break; + case PARSE_NO_ENTRIES: + LUA->ThrowError("Didn't find any valid entries in whitelist file!"); + break; + default: + break; + } + #endif + if( luaopen_socket_core( LUA->GetState( ) ) == 1 ) { LUA->Push( -1 ); @@ -15,6 +40,10 @@ GMOD_MODULE_OPEN( ) GMOD_MODULE_CLOSE( ) { + #ifdef USE_WHITELIST + clearWhitelist(); + #endif + LUA->PushNil( ); LUA->SetField( GarrysMod::Lua::INDEX_GLOBAL, "socket" ); return 0; diff --git a/source/whitelist.cpp b/source/whitelist.cpp new file mode 100644 index 0000000..43b89df --- /dev/null +++ b/source/whitelist.cpp @@ -0,0 +1,109 @@ +#undef getaddrinfo + +#include "socket.h" +#include +#include +#include +#include +#include + +//Somewhere glua can't read? +const char* whitelistDir = "../gm_socket_whitelist.txt"; +std::map > whitelist; + +enum : int +{ + PARSE_SUCCESS = 0, + PARSE_CANT_READ = 1, + PARSE_NO_ENTRIES = 2 +}; + +int parseWhitelist() +{ + std::ifstream input(whitelistDir); + if (input) + { + std::stringstream filereader; + filereader << input.rdbuf(); + std::string filedata = filereader.str(); + std::regex line_parser("(?:(?!\r?\n).)+"); + std::regex entry_parser("^[ \\t]*([\\w\\.\\*-]+)\\:(\\d+)[ \\t]*$"); + std::regex wildcard("\\*"); + std::regex dot("\\."); + for (std::sregex_iterator line = std::sregex_iterator(filedata.begin(), filedata.end(), line_parser), end = std::sregex_iterator(); line != end; ++line) + { + const std::string& linestr = line->operator[](0); + std::smatch match; + if(std::regex_match(linestr, match, entry_parser)) + { + std::string domain = match[1]; + domain = std::regex_replace(domain, wildcard, "[\\w-]+"); + domain = std::regex_replace(domain, dot, "\\."); + whitelist[match[2].str()].push_back(std::regex(domain)); + } + } + if (whitelist.empty()) + { + return PARSE_NO_ENTRIES; + } + } + else + { + return PARSE_CANT_READ; + } + return PARSE_SUCCESS; +} + +void clearWhitelist() +{ + whitelist.clear(); +} + +bool isSafe(const std::string& pNodeName, const std::string& pServiceName) +{ + std::map >::iterator domains = whitelist.find(pServiceName); + if (domains != whitelist.end()) + { + for (auto i = domains->second.begin(), end = domains->second.end(); i != end; ++i) + { + if (std::regex_match(pNodeName, *i)) + { + return true; + } + } + return false; + } + else + { + return false; + } +} + +extern "C" { + +#ifdef _WIN32 + INT WSAAPI __wrap_getaddrinfo( + _In_opt_ PCSTR pNodeName, + _In_opt_ PCSTR pServiceName, + _In_opt_ const ADDRINFOA * pHints, + _Outptr_result_maybenull_ PADDRINFOA * ppResult + ) +#else + int __wrap_getaddrinfo (__const char *__restrict pNodeName, + __const char *__restrict pServiceName, + __const struct addrinfo *__restrict pHints, + struct addrinfo **__restrict ppResult) +#endif + { + if(isSafe(pNodeName, pServiceName)) + { + return getaddrinfo(pNodeName, pServiceName, pHints, ppResult); + } + else + { + *ppResult = nullptr; + return EAI_FAIL; + } + } + +}