{-# LANGUAGE AllowAmbiguousTypes #-}

module IHP.LoginSupport.Middleware
    ( authMiddleware
    , adminAuthMiddleware
    , userIdMiddleware
    , adminIdMiddleware
    , fetchUserMiddleware
    , fetchAdminMiddleware
    , fetchUserMiddlewareFor
    , parseSessionUUID
    , authMiddlewareWith
    , currentUserVaultKey
    , currentAdminVaultKey
    , currentUserIdVaultKey
    , currentAdminIdVaultKey
    , lookupAuthVault
    ) where

import IHP.Prelude
import IHP.LoginSupport.Types
import IHP.LoginSupport.Helper.Controller (sessionKey)
import IHP.Controller.Session
import IHP.QueryBuilder
import IHP.Fetch
import IHP.ModelSupport
import IHP.Hasql.FromRow (FromRowHasql)
import qualified Network.Wai as Wai
import qualified Data.Vault.Lazy as Vault
import qualified Data.UUID as UUID
import qualified Data.ByteString as BS

-- | Middleware that reads a user id from the session and stores it in
-- 'currentUserIdVaultKey'. No database query is performed.
--
-- The argument is the name of the session entry holding the id. The work is
-- delegated to 'userIdMiddlewareFor', so a missing, empty or unparseable
-- session entry results in 'Nothing' being stored under the vault key.
--
-- This is useful when you only need the user's UUID (e.g. for row-level
-- security) and want to avoid the cost of a database fetch.
--
-- > option $ AuthMiddleware (userIdMiddleware (sessionKey @User))
--
-- For full user record access, compose with 'fetchUserMiddleware':
--
-- > option $ AuthMiddleware (userIdMiddleware (sessionKey @User) . fetchUserMiddleware @User)
--
userIdMiddleware :: ByteString -> Wai.Middleware
userIdMiddleware sessionKeyName = userIdMiddlewareFor sessionKeyName currentUserIdVaultKey
{-# INLINE userIdMiddleware #-}

-- | Same as 'userIdMiddleware' but stores the admin id in 'currentAdminIdVaultKey'.
--
-- The argument is the name of the session entry holding the admin id.
--
-- > option $ AuthMiddleware (adminIdMiddleware (sessionKey @Admin))
--
adminIdMiddleware :: ByteString -> Wai.Middleware
adminIdMiddleware sessionKeyName = userIdMiddlewareFor sessionKeyName currentAdminIdVaultKey
{-# INLINE adminIdMiddleware #-}

-- | Building block: reads a session key and stores the parsed UUID in the given vault key.
--
-- The vault key is always populated before the wrapped application runs.
-- It holds 'Nothing' when:
--
--   - 'lookupSessionVault' yields no session accessors for the request,
--   - the session has no entry under @sessionKeyName@,
--   - the entry is the empty string, or
--   - 'parseSessionUUID' cannot parse the entry.
--
-- Otherwise it holds @Just@ the parsed UUID.
userIdMiddlewareFor :: ByteString -> Vault.Key (Maybe UUID) -> Wai.Middleware
userIdMiddlewareFor sessionKeyName idKey app req respond = do
    userId <- case lookupSessionVault req of
        -- Only the lookup half of the session accessor pair is needed here.
        Just (lookupFn, _) -> do
            rawValue <- lookupFn sessionKeyName
            pure $ case rawValue of
                Nothing -> Nothing
                Just "" -> Nothing
                Just bs -> parseSessionUUID bs
        Nothing -> pure Nothing
    -- Hand the downstream application a request whose vault carries the id.
    let req' = req { Wai.vault = Vault.insert idKey userId (Wai.vault req) }
    app req' respond
{-# INLINE userIdMiddlewareFor #-}

-- | Parse UUID from session bytes. Handles both:
--
--   - New format: raw 36-byte UUID ASCII (e.g. \"550e8400-e29b-41d4-a716-446655440000\")
--   - Old format: 8-byte cereal length prefix + 36-byte UUID ASCII (44 bytes total)
--
-- The old format comes from sessions written with @Serialize (Id' table)@ which
-- prepends an 8-byte big-endian length prefix via cereal. We support both formats
-- so existing sessions continue to work without logging users out on upgrade.
--
-- The input is first tried as a plain ASCII UUID. Only if that fails and the
-- input is exactly 44 bytes long are the first 8 bytes dropped and the rest
-- parsed; the content of those 8 bytes is not inspected. Any other input
-- yields 'Nothing'.
--
-- TODO: Remove old format support after 2026-05-01. At that point all
-- session cookies using the cereal encoding will have expired.
parseSessionUUID :: ByteString -> Maybe UUID
parseSessionUUID bs
    | Just uuid <- UUID.fromASCIIBytes bs = Just uuid
    | BS.length bs == 44 = UUID.fromASCIIBytes (BS.drop 8 bs)
    | otherwise = Nothing
{-# INLINE parseSessionUUID #-}

-- | Middleware that reads the userId from 'currentUserIdVaultKey', fetches
-- the full user record from the database, and stores it in 'currentUserVaultKey'.
--
-- This is 'fetchUserMiddlewareFor' specialised to the user vault keys; the
-- @user@ type is chosen with a type application and must normalize to
-- 'CurrentUserRecord'.
--
-- Must be composed after 'userIdMiddleware':
--
-- > userIdMiddleware (sessionKey @User) . fetchUserMiddleware @User
--
fetchUserMiddleware :: forall user normalizedModel.
    ( normalizedModel ~ NormalizeModel user
    , normalizedModel ~ CurrentUserRecord
    , Typeable normalizedModel
    , Table normalizedModel
    , FromRowHasql normalizedModel
    , PrimaryKey (GetTableName normalizedModel) ~ UUID
    , GetTableName normalizedModel ~ GetTableName user
    , FilterPrimaryKey (GetTableName normalizedModel)
    ) => Wai.Middleware
fetchUserMiddleware = fetchUserMiddlewareFor @user currentUserIdVaultKey currentUserVaultKey
{-# INLINE fetchUserMiddleware #-}

-- | Middleware that reads the adminId from 'currentAdminIdVaultKey', fetches
-- the full admin record from the database, and stores it in 'currentAdminVaultKey'.
--
-- This is 'fetchUserMiddlewareFor' specialised to the admin vault keys; the
-- @admin@ type is chosen with a type application and must normalize to
-- 'CurrentAdminRecord'.
--
-- Must be composed after 'adminIdMiddleware':
--
-- > adminIdMiddleware (sessionKey @Admin) . fetchAdminMiddleware @Admin
--
fetchAdminMiddleware :: forall admin normalizedModel.
    ( normalizedModel ~ NormalizeModel admin
    , normalizedModel ~ CurrentAdminRecord
    , Typeable normalizedModel
    , Table normalizedModel
    , FromRowHasql normalizedModel
    , PrimaryKey (GetTableName normalizedModel) ~ UUID
    , GetTableName normalizedModel ~ GetTableName admin
    , FilterPrimaryKey (GetTableName normalizedModel)
    ) => Wai.Middleware
fetchAdminMiddleware = fetchUserMiddlewareFor @admin currentAdminIdVaultKey currentAdminVaultKey
{-# INLINE fetchAdminMiddleware #-}

-- | Building block: reads a UUID from the given ID vault key, fetches the
-- record from the database, and stores it in the given user vault key.
--
-- The record vault key is always populated before the wrapped application
-- runs. When the ID vault key yields no UUID, no fetch is issued and
-- 'Nothing' is stored; otherwise the result of 'fetchOneOrNothing' for that
-- primary key is stored.
fetchUserMiddlewareFor :: forall user normalizedModel.
    ( normalizedModel ~ NormalizeModel user
    , Typeable normalizedModel
    , Table normalizedModel
    , FromRowHasql normalizedModel
    , PrimaryKey (GetTableName normalizedModel) ~ UUID
    , GetTableName normalizedModel ~ GetTableName user
    , FilterPrimaryKey (GetTableName normalizedModel)
    ) => Vault.Key (Maybe UUID) -> Vault.Key (Maybe normalizedModel) -> Wai.Middleware
fetchUserMiddlewareFor idKey userKey app req respond = do
    -- 'fetchOneOrNothing' takes its model context as an implicit parameter.
    let ?modelContext = req.modelContext
    user <- case lookupAuthVault idKey req of
        Just uuid -> fetchOneOrNothing (Id uuid)
        Nothing -> pure Nothing
    -- Hand the downstream application a request whose vault carries the record.
    let req' = req { Wai.vault = Vault.insert userKey user (Wai.vault req) }
    app req' respond
{-# INLINE fetchUserMiddlewareFor #-}

-- | Middleware that authenticates the current user and stores it in the request vault
-- using 'currentUserVaultKey'.
--
-- This is the standard middleware for user authentication. Add it to your Config.hs:
--
-- > import IHP.LoginSupport.Middleware
-- >
-- > config :: ConfigBuilder
-- > config = do
-- >     option $ AuthMiddleware (authMiddleware @User)
--
-- For both user and admin authentication:
--
-- > option $ AuthMiddleware (authMiddleware @User . adminAuthMiddleware @Admin)
--
-- This is equivalent to @userIdMiddleware (sessionKey \@User) . fetchUserMiddleware \@User@:
-- the id middleware is the outer layer, so the user id is placed in the vault
-- before the fetch middleware reads it.
--
authMiddleware :: forall user normalizedModel.
    ( normalizedModel ~ NormalizeModel user
    , normalizedModel ~ CurrentUserRecord
    , Typeable normalizedModel
    , Table normalizedModel
    , FromRowHasql normalizedModel
    , PrimaryKey (GetTableName normalizedModel) ~ UUID
    , GetTableName normalizedModel ~ GetTableName user
    , FilterPrimaryKey (GetTableName normalizedModel)
    , KnownSymbol (GetModelName user)
    ) => Wai.Middleware
authMiddleware = userIdMiddleware (sessionKey @user) . fetchUserMiddleware @user
{-# INLINE authMiddleware #-}

-- | Middleware that authenticates the current admin and stores it in the request vault
-- using 'currentAdminVaultKey'.
--
-- > option $ AuthMiddleware (authMiddleware @User . adminAuthMiddleware @Admin)
--
-- This is equivalent to @adminIdMiddleware (sessionKey \@Admin) . fetchAdminMiddleware \@Admin@:
-- the id middleware is the outer layer, so the admin id is placed in the vault
-- before the fetch middleware reads it.
--
adminAuthMiddleware :: forall admin normalizedModel.
    ( normalizedModel ~ NormalizeModel admin
    , normalizedModel ~ CurrentAdminRecord
    , Typeable normalizedModel
    , Table normalizedModel
    , FromRowHasql normalizedModel
    , PrimaryKey (GetTableName normalizedModel) ~ UUID
    , GetTableName normalizedModel ~ GetTableName admin
    , FilterPrimaryKey (GetTableName normalizedModel)
    , KnownSymbol (GetModelName admin)
    ) => Wai.Middleware
adminAuthMiddleware = adminIdMiddleware (sessionKey @admin) . fetchAdminMiddleware @admin
{-# INLINE adminAuthMiddleware #-}

-- | Low-level building block: middleware that runs a fetch function and stores
-- the result in the request vault under the given key.
--
-- The fetch function receives the incoming request and runs once per request,
-- before the wrapped application. Its result, including 'Nothing', is
-- inserted into the vault of the request passed downstream.
--
-- This decouples the vault insertion from the database lookup, making it
-- useful for testing and custom authentication schemes.
authMiddlewareWith :: Vault.Key (Maybe user) -> (Wai.Request -> IO (Maybe user)) -> Wai.Middleware
authMiddlewareWith key fetchUser app req respond = do
    user <- fetchUser req
    let req' = req { Wai.vault = Vault.insert key user (Wai.vault req) }
    app req' respond
{-# INLINE authMiddlewareWith #-}