{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE UndecidableInstances #-}
{-|
Module: IHP.Controller.Session
Description: Functions to work with session cookies, provides 'setSession', 'getSession' and friends
Copyright: (c) digitally induced GmbH, 2021

The session provides a way for your application to store small amounts of information that will be persisted between requests. It’s mainly used from inside your controller actions.

In general, you should not store complex data structures in the session. It’s better to store scalar values in there only. For example: Store the current user-id instead of the current user record.

The session works by storing the data inside a cryptographically signed and encrypted cookie on the client. The encryption key is generated automatically and is stored at @Config/client_session_key.aes@. Internally IHP uses the clientsession library. You can find more technical details on the implementation in the <https://hackage.haskell.org/package/clientsession-0.9.1.2/docs/Web-ClientSession.html clientsession> documentation.

The cookie @max-age@ is set to 30 days by default. To protect against CSRF, the @SameSite@ Policy is set to @Lax@.
-}
module IHP.Controller.Session
  (
  -- * Session Error
  SessionError (..)

  -- * Interacting with session store
  , setSession
  , getSession
  , getSessionEither
  , deleteSession
  , getSessionAndClear
  ) where

import IHP.Prelude
import IHP.Controller.RequestContext
import IHP.Controller.Context
import IHP.ModelSupport
import qualified Data.UUID as UUID
import qualified Data.Vault.Lazy as Vault
import qualified Network.Wai as Wai
import qualified Data.Serialize as Serialize
import Data.Serialize (Serialize)
import Data.Serialize.Text ()

-- | Types of possible errors as a result of
-- requesting a value from the session storage with 'getSessionEither'.
data SessionError
    -- | No value was found in the session storage for the requested key
    -- (an empty stored value, as written by 'deleteSession', also counts
    -- as not found).
    = NotFoundError
    -- | A value was found, but an error occurred while decoding it into the
    -- requested type. Carries the decoder's error message.
    | ParseError String
    deriving (Show, Eq)

-- | Stores a value inside the session:
--
-- > action SessionExampleAction { userId } = do
-- >     setSession "userId" userId
--
-- The value is encoded with its 'Serialize' instance before it is written,
-- so it can be read back with 'getSession' at the same type.
--
-- For cases where 'setSession' is used with literals, you can avoid
-- type ambiguity with one of the options below.
--
-- __Example:__ Annotate a literal with a type
--
-- > action SessionExampleAction = do
-- >     setSession "userEmail" ("hi@digitallyinduced.com" :: Text)
--
-- __Example:__ Using setSession with a type application
--
-- > action SessionExampleAction = do
-- >     setSession @Text "userEmail" "hi@digitallyinduced.com"
--
setSession :: (?context :: ControllerContext, Serialize value)
           => ByteString -> value -> IO ()
setSession name value = sessionInsert name (Serialize.encode value)
{-# INLINABLE setSession #-}

-- | Retrieves a value from the session:
--
-- > action SessionExampleAction = do
-- >     userEmail <- getSession @Text "userEmail"
-- >     counter <- getSession @Int "counter"
-- >     userId <- getSession @(Id User) "userId"
--
-- @userEmail@ is set to @Just "hi\@digitallyinduced.com"@
-- when the value has been set before. Otherwise, it will be 'Nothing'.
--
-- Every 'SessionError' is collapsed into 'Nothing': both a missing key and
-- a stored value that cannot be decoded as the requested type.
-- Use 'getSessionEither' to tell these cases apart.
getSession :: forall value
            . (?context :: ControllerContext, Serialize value)
           => ByteString -> IO (Maybe value)
getSession name = getSessionEither name >>= \case
    Left _ -> pure Nothing
    Right result -> pure (Just result)
{-# INLINABLE getSession #-}

-- | Retrieves a value from the session.
--
-- A 'getSession' variant which returns a 'SessionError' describing why
-- the value could not be retrieved:
--
-- > action SessionExampleAction = do
-- >     counter <- getSessionEither @Int "counter"
-- >     case counter of
-- >         Right value -> ...
-- >         Left (ParseError errorMessage) -> ...
-- >         Left NotFoundError -> ...
--
-- Returns @Left NotFoundError@ when the key is absent or holds an empty
-- value (which is what 'deleteSession' writes). Returns
-- @Left (ParseError message)@ when the stored bytes cannot be decoded
-- as @value@.
getSessionEither :: forall value
            . (?context :: ControllerContext, Serialize value)
           => ByteString -> IO (Either SessionError value)
getSessionEither name = sessionLookup name >>= \case
        Nothing -> pure (Left NotFoundError)
        -- 'deleteSession' stores an empty value, so treat it as absent.
        Just "" -> pure (Left NotFoundError)
        Just encoded -> pure $ case Serialize.decode encoded of
            -- Named 'message' (not 'error') to avoid shadowing Prelude's 'error'.
            Left message -> Left (ParseError message)
            Right decoded -> Right decoded
{-# INLINABLE getSessionEither #-}

-- | Remove session values from storage:
--
-- __Example:__ Deleting a @userId@ field from the session
--
-- > action LogoutAction = do
-- >     deleteSession "userId"
--
-- __Example:__ Calling 'getSession' after
-- using 'deleteSession' will return @Nothing@
--
-- > setSession "userId" (1337 :: Int)
-- > userId <- getSession @Int "userId" -- Returns: Just 1337
-- >
-- > deleteSession "userId"
-- > userId <- getSession @Int "userId" -- Returns: Nothing
--
-- Rather than removing the key, this overwrites it with an empty value.
-- 'getSession' and 'getSessionEither' treat an empty value as missing.
deleteSession :: (?context :: ControllerContext) => ByteString -> IO ()
deleteSession name = sessionInsert name ""

-- | Returns a value from the session, and deletes it after retrieving:
--
-- > action SessionExampleAction = do
-- >     notification <- getSessionAndClear @Text "notification"
--
-- The key is only cleared when a value was successfully read. A stored
-- value that cannot be decoded as @value@ yields 'Nothing' and is left
-- in the session.
getSessionAndClear :: forall value
                    . (?context :: ControllerContext, Serialize value)
                   => ByteString -> IO (Maybe value)
getSessionAndClear name = do
    value <- getSession @value name
    when (isJust value) (deleteSession name)
    pure value
{-# INLINABLE getSessionAndClear #-}

-- | Lets UUID-keyed record ids (e.g. @Id User@) be stored in the session
-- with 'setSession' and read back with 'getSession'.
--
-- The UUID is written in its ASCII text form ('UUID.toASCIIBytes') and
-- serialized as a 'ByteString'. Decoding fails with
-- @"Failed to parse UUID"@ if the stored bytes are not a valid ASCII UUID.
--
-- NOTE(review): presumably an orphan instance ('Serialize' and 'Id'' both
-- appear to come from other modules); consider moving it next to 'Id''.
instance (PrimaryKey table ~ UUID) => Serialize (Id' table) where
    put (Id value) = Serialize.put (UUID.toASCIIBytes value)
    get = do
        maybeUUID <- UUID.fromASCIIBytes <$> Serialize.get
        case maybeUUID of
            Nothing -> fail "Failed to parse UUID"
            Just uuid -> pure (Id uuid)

-- | Writes an already-encoded value into the session under the given key,
-- using the insert function from the request's session vault.
sessionInsert :: (?context :: ControllerContext) => ByteString -> ByteString -> IO ()
sessionInsert = snd sessionVault

-- | Reads the raw, still-encoded value stored under the given key, using the
-- lookup function from the request's session vault.
sessionLookup :: (?context :: ControllerContext) => ByteString -> IO (Maybe ByteString)
sessionLookup = fst sessionVault

-- | The session's @(lookup, insert)@ function pair, fetched from the WAI
-- request vault using the vault key held in the 'RequestContext'.
--
-- Calls 'error' if the request vault has no entry for that key.
-- NOTE(review): presumably this means the session middleware did not run
-- for this request.
sessionVault :: (?context :: ControllerContext) => (ByteString -> IO (Maybe ByteString), ByteString -> ByteString -> IO ())
sessionVault = case vaultLookup of
        Just session -> session
        -- Named after this function: it is reached from both
        -- 'sessionLookup' and 'sessionInsert'.
        Nothing -> error "sessionVault: The session vault is missing in the request"
    where
        RequestContext { request, vault } = ?context.requestContext
        vaultLookup = Vault.lookup vault (Wai.vault request)