Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CmsgIdFd -> CmsgIdFds #575

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions Network/Socket/Win32/Cmsg.hsc
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
Expand Down Expand Up @@ -70,8 +71,8 @@ pattern CmsgIdIPv6PktInfo = CmsgId (#const IPPROTO_IPV6) (#const IPV6_PKTINFO)
-- | Control message ID for POSIX file-descriptor passing.
--
-- Not supported on Windows; use WSADuplicateSocket instead
pattern CmsgIdFd :: CmsgId
pattern CmsgIdFd = CmsgId (-1) (-1)
pattern CmsgIdFds :: CmsgId
pattern CmsgIdFds = CmsgId (-1) (-1)

----------------------------------------------------------------

Expand All @@ -91,11 +92,13 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs
----------------------------------------------------------------

-- | A class to encode and decode control message.
class Storable a => ControlMessage a where
class ControlMessage a where
controlMessageId :: CmsgId
encodeCmsg :: a -> Cmsg
decodeCmsg :: Cmsg -> Maybe a

encodeCmsg :: forall a. ControlMessage a => a -> Cmsg
encodeCmsg x = unsafeDupablePerformIO $ do
encodeStorableCmsg :: forall a. (ControlMessage a, Storable a) => a -> Cmsg
encodeStorableCmsg x = unsafeDupablePerformIO $ do
bs <- create siz $ \p0 -> do
let p = castPtr p0
poke p x
Expand All @@ -104,8 +107,8 @@ encodeCmsg x = unsafeDupablePerformIO $ do
where
siz = sizeOf x

decodeCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
decodeCmsg (Cmsg cmsid (PS fptr off len))
decodeStorableCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
decodeStorableCmsg (Cmsg cmsid (PS fptr off len))
| cid /= cmsid = Nothing
| len < siz = Nothing
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
Expand All @@ -122,6 +125,8 @@ newtype IPv4TTL = IPv4TTL DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv4TTL where
controlMessageId = CmsgIdIPv4TTL
decodeCmsg = decodeStorableCmsg
encodeCmsg = encodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -130,6 +135,8 @@ newtype IPv6HopLimit = IPv6HopLimit DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv6HopLimit where
controlMessageId = CmsgIdIPv6HopLimit
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -138,6 +145,8 @@ newtype IPv4TOS = IPv4TOS DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv4TOS where
controlMessageId = CmsgIdIPv4TOS
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -146,6 +155,8 @@ newtype IPv6TClass = IPv6TClass DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv6TClass where
controlMessageId = CmsgIdIPv6TClass
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -158,6 +169,8 @@ instance Show IPv4PktInfo where

instance ControlMessage IPv4PktInfo where
controlMessageId = CmsgIdIPv4PktInfo
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

instance Storable IPv4PktInfo where
sizeOf ~_ = #{size IN_PKTINFO}
Expand All @@ -180,6 +193,8 @@ instance Show IPv6PktInfo where

instance ControlMessage IPv6PktInfo where
controlMessageId = CmsgIdIPv6PktInfo
decodeCmsg = decodeStorableCmsg
encodeCmsg = encodeStorableCmsg

instance Storable IPv6PktInfo where
sizeOf ~_ = #{size IN6_PKTINFO}
Expand All @@ -192,8 +207,14 @@ instance Storable IPv6PktInfo where
n :: ULONG <- (#peek IN6_PKTINFO, ipi6_ifindex) p
return $ IPv6PktInfo (fromIntegral n) ha6

instance ControlMessage Fd where
controlMessageId = CmsgIdFd
----------------------------------------------------------------

instance ControlMessage [Fd] where
controlMessageId = CmsgIdFds
encodeCmsg = \_ -> Cmsg CmsgIdFds ""
decodeCmsg = \_ -> Just []

----------------------------------------------------------------

cmsgIdBijection :: Bijection CmsgId String
cmsgIdBijection =
Expand All @@ -204,7 +225,7 @@ cmsgIdBijection =
, (CmsgIdIPv6TClass, "CmsgIdIPv6TClass")
, (CmsgIdIPv4PktInfo, "CmsgIdIPv4PktInfo")
, (CmsgIdIPv6PktInfo, "CmsgIdIPv6PktInfo")
, (CmsgIdFd, "CmsgIdFd")
, (CmsgIdFds, "CmsgIdFds")
]

instance Show CmsgId where
Expand Down
3 changes: 3 additions & 0 deletions network.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,6 @@ test-suite spec

if impl(ghc >=8)
default-extensions: Strict StrictData

if os(windows)
cpp-options: -D_WIN32
2 changes: 2 additions & 0 deletions tests/Network/SocketSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,11 @@
let msgid = CmsgId (-300) (-300) in
show msgid `shouldBe` "CmsgId (-300) (-300)"

#if !defined(_WIN32)
describe "bijective encodeCmsg-decodeCmsg roundtrip equality" $ do
it "holds for [Fd]" $ forAll genFds $
\x -> (decodeCmsg . encodeCmsg $ x) == Just (x :: [Fd])
#endif

describe "bijective read-show roundtrip equality" $ do
it "holds for Family" $ forAll familyGen $
Expand Down Expand Up @@ -421,7 +423,7 @@
cmsgidGen = biasedGen (\g -> CmsgId <$> g <*> g) cmsgidPatterns arbitrary

genFds :: Gen [Fd]
genFds = listOf (Fd <$> arbitrary)

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 8.4)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 8.4)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.4)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.4)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.6)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.6)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.8)

Defined but not used: ‘genFds’

Check warning on line 426 in tests/Network/SocketSpec.hs

View workflow job for this annotation

GitHub Actions / build (windows-latest, 9.8)

Defined but not used: ‘genFds’

-- pruned lists of pattern synonym values for each type to generate values from

Expand Down
Loading