Skip to content

Commit

Permalink
Ensure that a LocalEnv is only used in a thread it belongs to
Browse files Browse the repository at this point in the history
  • Loading branch information
arybczak committed Sep 28, 2024
1 parent 255b153 commit e778e23
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 79 deletions.
1 change: 1 addition & 0 deletions effectful-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
policy for the `NonDet` effect is selected.
* Add a `SeqForkUnlift` strategy to support running unlifting functions outside
of the scope of effects they capture.
* Ensure that a `LocalEnv` is only used in a thread it belongs to.
* **Breaking changes**:
- `localSeqLend`, `localLend`, `localSeqBorrow` and `localBorrow` now take a
list of effects instead of a single one.
Expand Down
140 changes: 73 additions & 67 deletions effectful-core/src/Effectful/Dispatch/Dynamic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ module Effectful.Dispatch.Dynamic
, HasCallStack
) where

import Control.Monad
import Control.Monad.IO.Unlift
import Data.Primitive.PrimArray
import GHC.Stack (HasCallStack)
import GHC.TypeLits
Expand Down Expand Up @@ -157,6 +155,7 @@ import Effectful.Internal.Utils
--
-- >>> import Control.Exception (IOException)
-- >>> import Control.Monad.Catch (catch)
-- >>> import Control.Monad.IO.Class
-- >>> import qualified System.IO as IO
--
-- >>> import Effectful.Error.Static
Expand Down Expand Up @@ -250,6 +249,7 @@ import Effectful.Internal.Utils
--
-- If we naively try to interpret it, we will run into trouble:
--
-- >>> import Control.Monad.IO.Class
-- >>> import GHC.Clock (getMonotonicTime)
--
-- >>> :{
Expand Down Expand Up @@ -490,6 +490,7 @@ reinterpretWith runHandlerEs m handler = reinterpret runHandlerEs handler m
-- type instance DispatchOf E = Dynamic
-- :}
--
-- >>> import Control.Monad.IO.Class
-- >>> :{
-- runE :: IOE :> es => Eff (E : es) a -> Eff es a
-- runE = interpret_ $ \Op -> liftIO (putStrLn "op")
Expand Down Expand Up @@ -712,6 +713,7 @@ localSeqUnlift
-- ^ Continuation with the unlifting function in scope.
-> Eff es a
localSeqUnlift (LocalEnv les) k = unsafeEff $ \es -> do
requireMatchingStorages es les
seqUnliftIO les $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
{-# INLINE localSeqUnlift #-}
Expand All @@ -725,7 +727,9 @@ localSeqUnliftIO
-> ((forall r. Eff localEs r -> IO r) -> IO a)
-- ^ Continuation with the unlifting function in scope.
-> Eff es a
localSeqUnliftIO (LocalEnv les) k = liftIO $ seqUnliftIO les k
localSeqUnliftIO (LocalEnv les) k = unsafeEff $ \es -> do
requireMatchingStorages es les
seqUnliftIO les k
{-# INLINE localSeqUnliftIO #-}

-- | Create a local unlifting function with the given strategy.
Expand All @@ -737,15 +741,14 @@ localUnlift
-> ((forall r. Eff localEs r -> Eff es r) -> Eff es a)
-- ^ Continuation with the unlifting function in scope.
-> Eff es a
localUnlift (LocalEnv les) strategy k = case strategy of
SeqUnlift -> unsafeEff $ \es -> do
seqUnliftIO les $ \unlift -> do
localUnlift (LocalEnv les) strategy k = unsafeEff $ \es -> do
requireMatchingStorages es les
case strategy of
SeqUnlift -> seqUnliftIO les $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
SeqForkUnlift -> unsafeEff $ \es -> do
seqForkUnliftIO les $ \unlift -> do
SeqForkUnlift -> seqForkUnliftIO les $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
ConcUnlift p l -> unsafeEff $ \es -> do
concUnliftIO les p l $ \unlift -> do
ConcUnlift p l -> concUnliftIO les p l $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
{-# INLINE localUnlift #-}

Expand All @@ -758,10 +761,12 @@ localUnliftIO
-> ((forall r. Eff localEs r -> IO r) -> IO a)
-- ^ Continuation with the unlifting function in scope.
-> Eff es a
localUnliftIO (LocalEnv les) strategy k = case strategy of
SeqUnlift -> liftIO $ seqUnliftIO les k
SeqForkUnlift -> liftIO $ seqForkUnliftIO les k
ConcUnlift p l -> liftIO $ concUnliftIO les p l k
localUnliftIO (LocalEnv les) strategy k = unsafeEff $ \es -> do
requireMatchingStorages es les
case strategy of
SeqUnlift -> seqUnliftIO les k
SeqForkUnlift -> seqForkUnliftIO les k
ConcUnlift p l -> concUnliftIO les p l k
{-# INLINE localUnliftIO #-}

----------------------------------------
Expand All @@ -778,9 +783,8 @@ localSeqLift
-> ((forall r. Eff es r -> Eff localEs r) -> Eff es a)
-- ^ Continuation with the lifting function in scope.
-> Eff es a
localSeqLift !_ k = unsafeEff $ \es -> do
-- The LocalEnv parameter is not used, but we need it to constraint the
-- localEs type variable. It's also strict so that callers don't cheat.
localSeqLift (LocalEnv les) k = unsafeEff $ \es -> do
requireMatchingStorages es les
seqUnliftIO es $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
{-# INLINE localSeqLift #-}
Expand All @@ -796,17 +800,14 @@ localLift
-> ((forall r. Eff es r -> Eff localEs r) -> Eff es a)
-- ^ Continuation with the lifting function in scope.
-> Eff es a
localLift !_ strategy k = case strategy of
-- The LocalEnv parameter is not used, but we need it to constraint the
-- localEs type variable. It's also strict so that callers don't cheat.
SeqUnlift -> unsafeEff $ \es -> do
seqUnliftIO es $ \unlift -> do
localLift (LocalEnv les) strategy k = unsafeEff $ \es -> do
requireMatchingStorages es les
case strategy of
SeqUnlift -> seqUnliftIO es $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
SeqForkUnlift -> unsafeEff $ \es -> do
seqForkUnliftIO es $ \unlift -> do
SeqForkUnlift -> seqForkUnliftIO es $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
ConcUnlift p l -> unsafeEff $ \es -> do
concUnliftIO es p l $ \unlift -> do
ConcUnlift p l -> concUnliftIO es p l $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
{-# INLINE localLift #-}

Expand All @@ -827,9 +828,8 @@ withLiftMap
-> ((forall a b. (Eff es a -> Eff es b) -> Eff localEs a -> Eff localEs b) -> Eff es r)
-- ^ Continuation with the lifting function in scope.
-> Eff es r
withLiftMap !_ k = unsafeEff $ \es -> do
-- The LocalEnv parameter is not used, but we need it to constraint the
-- localEs type variable. It's also strict so that callers don't cheat.
withLiftMap (LocalEnv les) k = unsafeEff $ \es -> do
requireMatchingStorages es les
(`unEff` es) $ k $ \mapEff m -> unsafeEff $ \localEs -> do
seqUnliftIO localEs $ \unlift -> do
(`unEff` es) . mapEff . unsafeEff_ $ unlift m
Expand Down Expand Up @@ -868,9 +868,8 @@ withLiftMapIO
-> ((forall a b. (IO a -> IO b) -> Eff localEs a -> Eff localEs b) -> Eff es r)
-- ^ Continuation with the lifting function in scope.
-> Eff es r
withLiftMapIO !_ k = k $ \mapIO m -> unsafeEff $ \es -> do
-- The LocalEnv parameter is not used, but we need it to constraint the
-- localEs type variable. It's also strict so that callers don't cheat.
withLiftMapIO (LocalEnv les) k = k $ \mapIO m -> unsafeEff $ \es -> do
requireMatchingStorages es les
seqUnliftIO es $ \unlift -> mapIO $ unlift m
{-# INLINE withLiftMapIO #-}

Expand All @@ -892,17 +891,16 @@ localLiftUnlift
-> ((forall r. Eff es r -> Eff localEs r) -> (forall r. Eff localEs r -> Eff es r) -> Eff es a)
-- ^ Continuation with the lifting and unlifting function in scope.
-> Eff es a
localLiftUnlift (LocalEnv les) strategy k = case strategy of
SeqUnlift -> unsafeEff $ \es -> do
seqUnliftIO es $ \unliftEs -> do
localLiftUnlift (LocalEnv les) strategy k = unsafeEff $ \es -> do
requireMatchingStorages es les
case strategy of
SeqUnlift -> seqUnliftIO es $ \unliftEs -> do
seqUnliftIO les $ \unliftLocalEs -> do
(`unEff` es) $ k (unsafeEff_ . unliftEs) (unsafeEff_ . unliftLocalEs)
SeqForkUnlift -> unsafeEff $ \es -> do
seqForkUnliftIO es $ \unliftEs -> do
SeqForkUnlift -> seqForkUnliftIO es $ \unliftEs -> do
seqForkUnliftIO les $ \unliftLocalEs -> do
(`unEff` es) $ k (unsafeEff_ . unliftEs) (unsafeEff_ . unliftLocalEs)
ConcUnlift p l -> unsafeEff $ \es -> do
concUnliftIO es p l $ \unliftEs -> do
ConcUnlift p l -> concUnliftIO es p l $ \unliftEs -> do
concUnliftIO les p l $ \unliftLocalEs -> do
(`unEff` es) $ k (unsafeEff_ . unliftEs) (unsafeEff_ . unliftLocalEs)
{-# INLINE localLiftUnlift #-}
Expand All @@ -923,10 +921,12 @@ localLiftUnliftIO
-> ((forall r. IO r -> Eff localEs r) -> (forall r. Eff localEs r -> IO r) -> IO a)
-- ^ Continuation with the lifting and unlifting function in scope.
-> Eff es a
localLiftUnliftIO (LocalEnv les) strategy k = case strategy of
SeqUnlift -> liftIO $ seqUnliftIO les $ k unsafeEff_
SeqForkUnlift -> liftIO $ seqForkUnliftIO les $ k unsafeEff_
ConcUnlift p l -> liftIO $ concUnliftIO les p l $ k unsafeEff_
localLiftUnliftIO (LocalEnv les) strategy k = unsafeEff $ \es -> do
requireMatchingStorages es les
case strategy of
SeqUnlift -> seqUnliftIO les $ k unsafeEff_
SeqForkUnlift -> seqForkUnliftIO les $ k unsafeEff_
ConcUnlift p l -> concUnliftIO les p l $ k unsafeEff_
{-# INLINE localLiftUnliftIO #-}

----------------------------------------
Expand Down Expand Up @@ -1000,16 +1000,15 @@ localLend
-> ((forall r. Eff (lentEs ++ localEs) r -> Eff localEs r) -> Eff es a)
-- ^ Continuation with the lent handler in scope.
-> Eff es a
localLend (LocalEnv les) strategy k = case strategy of
SeqUnlift -> unsafeEff $ \es -> do
eles <- copyRefs @lentEs es les
seqUnliftIO eles $ \unlift -> (`unEff` es) $ k $ unsafeEff_ . unlift
SeqForkUnlift -> unsafeEff $ \es -> do
eles <- copyRefs @lentEs es les
seqForkUnliftIO eles $ \unlift -> (`unEff` es) $ k $ unsafeEff_ . unlift
ConcUnlift p l -> unsafeEff $ \es -> do
eles <- copyRefs @lentEs es les
concUnliftIO eles p l $ \unlift -> (`unEff` es) $ k $ unsafeEff_ . unlift
localLend (LocalEnv les) strategy k = unsafeEff $ \es -> do
eles <- copyRefs @lentEs es les
case strategy of
SeqUnlift -> seqUnliftIO eles $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
SeqForkUnlift -> seqForkUnliftIO eles $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
ConcUnlift p l -> concUnliftIO eles p l $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
{-# INLINE localLend #-}

-- | Borrow effects from the local environment.
Expand Down Expand Up @@ -1041,16 +1040,15 @@ localBorrow
-> ((forall r. Eff (borrowedEs ++ es) r -> Eff es r) -> Eff es a)
-- ^ Continuation with the borrowed handler in scope.
-> Eff es a
localBorrow (LocalEnv les) strategy k = case strategy of
SeqUnlift -> unsafeEff $ \es -> do
ees <- copyRefs @borrowedEs les es
seqUnliftIO ees $ \unlift -> (`unEff` es) $ k $ unsafeEff_ . unlift
SeqForkUnlift -> unsafeEff $ \es -> do
ees <- copyRefs @borrowedEs les es
seqForkUnliftIO ees $ \unlift -> (`unEff` es) $ k $ unsafeEff_ . unlift
ConcUnlift p l -> unsafeEff $ \es -> do
ees <- copyRefs @borrowedEs les es
concUnliftIO ees p l $ \unlift -> (`unEff` es) $ k $ unsafeEff_ . unlift
localBorrow (LocalEnv les) strategy k = unsafeEff $ \es -> do
ees <- copyRefs @borrowedEs les es
case strategy of
SeqUnlift -> seqUnliftIO ees $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
SeqForkUnlift -> seqForkUnliftIO ees $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
ConcUnlift p l -> concUnliftIO ees p l $ \unlift -> do
(`unEff` es) $ k $ unsafeEff_ . unlift
{-# INLINE localBorrow #-}

copyRefs
Expand All @@ -1059,9 +1057,8 @@ copyRefs
=> Env srcEs
-> Env destEs
-> IO (Env (es ++ destEs))
copyRefs (Env soffset srefs sstorage) (Env doffset drefs dstorage) = do
when (sstorage /= dstorage) $ do
error "storages do not match"
copyRefs src@(Env soffset srefs _) dest@(Env doffset drefs storage) = do
requireMatchingStorages src dest
let size = sizeofPrimArray drefs - doffset
es = reifyIndices @es @srcEs
esSize = 2 * length es
Expand All @@ -1076,9 +1073,18 @@ copyRefs (Env soffset srefs sstorage) (Env doffset drefs dstorage) = do
writeRefs (i + 2) xs
writeRefs 0 es
refs <- unsafeFreezePrimArray mrefs
pure $ Env 0 refs dstorage
pure $ Env 0 refs storage
{-# NOINLINE copyRefs #-}

requireMatchingStorages :: HasCallStack => Env es1 -> Env es2 -> IO ()
requireMatchingStorages es1 es2
| envStorage es1 /= envStorage es2 = error
$ "Env and LocalEnv point to different Storages.\n"
++ "If you passed LocalEnv to a different thread and tried to create an "
++ "unlifting function there, it's not allowed. You need to create it in "
++ "the thread of the effect handler."
| otherwise = pure ()

-- | Require that both effect stacks share an opaque suffix.
--
-- Functions from the 'localUnlift' family utilize this constraint to guarantee
Expand Down
1 change: 1 addition & 0 deletions effectful/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
policy for the `NonDet` effect is selected.
* Add a `SeqForkUnlift` strategy to support running unlifting functions outside
of the scope of effects they capture.
* Ensure that a `LocalEnv` is only used in a thread it belongs to.
* **Breaking changes**:
- `localSeqLend`, `localLend`, `localSeqBorrow` and `localBorrow` now take a
list of effects instead of a single one.
Expand Down
35 changes: 35 additions & 0 deletions effectful/tests/ConcurrencyTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import Test.Tasty.HUnit
import UnliftIO

import Effectful
import Effectful.Concurrent.Async qualified as E
import Effectful.Dispatch.Dynamic
import Effectful.Error.Static
import Effectful.State.Dynamic
import Utils qualified as U
Expand All @@ -20,6 +22,8 @@ concurrencyTests = testGroup "Concurrency"
, testCase "unlifting several times" test_unliftMany
, testCase "async with unmask" test_asyncWithUnmask
, testCase "pooled workers" test_pooledWorkers
, testCase "using local unlift correctly works" test_correctLocalUnlift
, testCase "using local unlift incorrectly doesn't work" test_wrongLocalUnlift
]

test_localState :: Assertion
Expand Down Expand Up @@ -119,3 +123,34 @@ test_pooledWorkers = runEff . evalStateLocal (0::Int) $ do
where
n = 10
threads = 4

test_correctLocalUnlift :: Assertion
test_correctLocalUnlift = runEff . E.runConcurrent $ do
x <- runFork . send . RunAsyncCorrect $ pure ()
E.wait x

test_wrongLocalUnlift :: Assertion
test_wrongLocalUnlift = runEff . E.runConcurrent $ do
U.assertThrowsErrorCall "invalid LocalEnv use" $ do
x <- runFork . send . RunAsyncWrong $ pure ()
E.wait x

data Fork :: Effect where
RunAsyncCorrect :: m a -> Fork m (E.Async a)
RunAsyncWrong :: m a -> Fork m (E.Async a)
type instance DispatchOf Fork = Dynamic

runFork :: (IOE :> es, E.Concurrent :> es) => Eff (Fork : es) a -> Eff es a
runFork = interpret $ \env -> \case
RunAsyncCorrect action -> do
-- LocalEnv is correctly used in the thread in belongs to, so creation of
-- the unlifting function should succeed.
localUnlift env strategy $ \unlift -> do
E.async $ unlift action
RunAsyncWrong action -> E.async $ do
-- LocalEnv is incorrectly passed to a different thread, so creation of the
-- unlifting function should fail.
localUnlift env strategy $ \unlift -> do
unlift action
where
strategy = ConcUnlift Ephemeral $ Limited 1
Loading

0 comments on commit e778e23

Please sign in to comment.