module IHP.DataSync.RowLevelSecurity
( ensureRLSEnabled
, hasRLSEnabled
, TableWithRLS (tableName)
, makeCachedEnsureRLSEnabled
, sqlQueryWithRLS
, sqlExecWithRLS
)
where

import IHP.ControllerPrelude
import qualified Database.PostgreSQL.Simple as PG
import qualified Database.PostgreSQL.Simple.ToField as PG
import qualified Database.PostgreSQL.Simple.Types as PG
import qualified Database.PostgreSQL.Simple.ToRow as PG
import qualified IHP.DataSync.Role as Role
import qualified Data.Set as Set

-- | Runs 'sqlQuery' with row level security context attached to the statement.
--
-- The query and its parameters are first rewritten by 'wrapStatementWithRLS',
-- which prefixes the statement with @SET LOCAL ROLE ?; SET LOCAL rls.ihp_user_id = ?;@.
-- This switches to the authenticated role and publishes the current user's id.
-- When nobody is logged in, the id is an empty string. The rewritten statement
-- is then executed and its rows are decoded via 'FromRow'.
--
-- The @userId@ type variable always equals @Id CurrentUserRecord@. It appears to
-- exist only so that the 'PG.ToField' requirement on the user id can be stated.
sqlQueryWithRLS ::
    ( ?modelContext :: ModelContext
    , PG.ToRow parameters
    , ?context :: ControllerContext
    , userId ~ Id CurrentUserRecord
    , Show (PrimaryKey (GetTableName CurrentUserRecord))
    , HasNewSessionUrl CurrentUserRecord
    , Typeable CurrentUserRecord
    , HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
    , PG.ToField userId
    , FromRow result
    ) => PG.Query -> parameters -> IO [result]
sqlQueryWithRLS query parameters = sqlQuery queryWithRLS parametersWithRLS
    where
        -- The RLS prefix adds two placeholders in front of the caller's query;
        -- the matching parameters are prepended in the same order.
        (queryWithRLS, parametersWithRLS) = wrapStatementWithRLS query parameters
{-# INLINE sqlQueryWithRLS #-}

-- | Runs 'sqlExec' with row level security context attached to the statement.
--
-- Like 'sqlQueryWithRLS', the statement is rewritten by 'wrapStatementWithRLS'.
-- That prefixes it with @SET LOCAL ROLE ?; SET LOCAL rls.ihp_user_id = ?;@ and
-- prepends the role identifier and encoded user id to the parameters.
-- Returns the 'Int64' produced by 'sqlExec' (presumably the affected row
-- count — confirm against 'sqlExec').
--
-- The @userId@ type variable always equals @Id CurrentUserRecord@. It appears to
-- exist only so that the 'PG.ToField' requirement on the user id can be stated.
sqlExecWithRLS ::
    ( ?modelContext :: ModelContext
    , PG.ToRow parameters
    , ?context :: ControllerContext
    , userId ~ Id CurrentUserRecord
    , Show (PrimaryKey (GetTableName CurrentUserRecord))
    , HasNewSessionUrl CurrentUserRecord
    , Typeable CurrentUserRecord
    , HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
    , PG.ToField userId
    ) => PG.Query -> parameters -> IO Int64
sqlExecWithRLS query parameters = sqlExec queryWithRLS parametersWithRLS
    where
        -- Same rewrite as in 'sqlQueryWithRLS': two RLS placeholders/parameters
        -- come first, followed by the caller's own ones.
        (queryWithRLS, parametersWithRLS) = wrapStatementWithRLS query parameters
{-# INLINE sqlExecWithRLS #-}

-- | Rewrites a statement so it runs with row level security context.
--
-- The returned query is
-- @SET LOCAL ROLE ?; SET LOCAL rls.ihp_user_id = ?; \<query\>;@.
-- The returned parameter list starts with the two values for those leading
-- placeholders: the quoted identifier of 'Role.authenticatedRole' and the
-- encoded current user id. The caller's own parameters follow in their
-- original order, so the @?@ placeholders inside @query@ still line up.
wrapStatementWithRLS ::
    ( ?modelContext :: ModelContext
    , PG.ToRow parameters
    , ?context :: ControllerContext
    , userId ~ Id CurrentUserRecord
    , Show (PrimaryKey (GetTableName CurrentUserRecord))
    , HasNewSessionUrl CurrentUserRecord
    , Typeable CurrentUserRecord
    , HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
    , PG.ToField userId
    ) => PG.Query -> parameters -> (PG.Query, [PG.Action])
wrapStatementWithRLS query parameters = (queryWithRLS, parametersWithRLS)
    where
        queryWithRLS :: PG.Query
        queryWithRLS = "SET LOCAL ROLE ?; SET LOCAL rls.ihp_user_id = ?; " <> query <> ";"

        maybeUserId :: Maybe (Id CurrentUserRecord)
        maybeUserId = (.id) <$> currentUserOrNothing

        -- When the user is not logged in and maybeUserId is Nothing, we cannot
        -- just pass @NULL@ to postgres. The @SET LOCAL@ values can only be strings.
        --
        -- Therefore we map Nothing to an empty string here. The empty string
        -- means "not logged in".
        encodedUserId :: PG.Action
        encodedUserId = case maybeUserId of
            Just userId -> PG.toField userId
            Nothing -> PG.toField ("" :: Text)

        -- Order matters: these two values fill the two placeholders of the
        -- SET LOCAL prefix, which precede every placeholder of the caller's query.
        -- encodedUserId is already an Action, so it is used directly.
        parametersWithRLS :: [PG.Action]
        parametersWithRLS =
            PG.toField (PG.Identifier Role.authenticatedRole)
                : encodedUserId
                : PG.toRow parameters
{-# INLINE wrapStatementWithRLS #-}

-- | Returns a proof that RLS is enabled for a table.
--
-- Checks the table with 'hasRLSEnabled'. On success, wraps the table name in a
-- 'TableWithRLS'. When row level security is not enabled, it fails with an
-- 'error' whose message names the offending table.
ensureRLSEnabled :: (?modelContext :: ModelContext) => Text -> IO TableWithRLS
ensureRLSEnabled table = do
    rlsEnabled <- hasRLSEnabled table
    -- Include the table name so the failure can be traced to a specific table.
    unless rlsEnabled $
        error ("Row level security is required for accessing this table: " <> table)
    pure (TableWithRLS table)

-- | Builds a memoizing variant of 'ensureRLSEnabled'.
--
-- The returned function remembers every table name that has already passed
-- 'ensureRLSEnabled'. Later calls for such a table return a 'TableWithRLS'
-- proof without touching the database. A table that fails the check is not
-- remembered, so it is checked again on the next call.
--
-- __Example:__
--
-- > -- Setup
-- > ensureRLSEnabled <- makeCachedEnsureRLSEnabled
-- >
-- > ensureRLSEnabled "projects" -- Runs a database query to check if row level security is enabled for the projects table
-- >
-- > -- Assuming 'ensureRLSEnabled "projects"' completed without errors:
-- >
-- > ensureRLSEnabled "projects" -- Now this returns the proof instantly, without firing any SQL queries
--
makeCachedEnsureRLSEnabled :: (?modelContext :: ModelContext) => IO (Text -> IO TableWithRLS)
makeCachedEnsureRLSEnabled = do
    verifiedTables <- newIORef Set.empty
    let ensureCached tableName = do
            cache <- readIORef verifiedTables
            if tableName `Set.member` cache
                then pure (TableWithRLS tableName)
                else do
                    -- Only record the table after the check succeeded;
                    -- a failing check throws before we get here.
                    proof <- ensureRLSEnabled tableName
                    modifyIORef' verifiedTables (Set.insert tableName)
                    pure proof
    pure ensureCached

-- | Returns 'True' if row level security has been enabled on a table.
--
-- RLS can be enabled with this SQL statement:
--
-- > ALTER TABLE my_table ENABLE ROW LEVEL SECURITY;
--
-- After this, 'hasRLSEnabled' returns 'True':
--
-- >>> hasRLSEnabled "my_table"
-- True
--
-- The flag is read from @pg_class.relrowsecurity@.
--
-- NOTE(review): an unknown table name makes the @regclass@ cast fail, so this
-- presumably throws rather than returning 'False' — confirm if callers rely on it.
hasRLSEnabled :: (?modelContext :: ModelContext) => Text -> IO Bool
hasRLSEnabled table = sqlQueryScalar rowSecurityFlagQuery [table]
    where
        -- quote_ident makes the regclass cast resolve the name as spelled
        -- instead of case-folding it.
        rowSecurityFlagQuery :: PG.Query
        rowSecurityFlagQuery = "SELECT relrowsecurity FROM pg_class WHERE oid = quote_ident(?)::regclass"

-- | Evidence that row level security is enabled for the wrapped table.
--
-- Obtain a value through 'ensureRLSEnabled' (or the memoizing function returned
-- by 'makeCachedEnsureRLSEnabled'):
--
-- > tableWithRLS <- ensureRLSEnabled "my_table"
--
-- The module export list exposes only the 'tableName' field, not the
-- constructor. Code outside this module can therefore read the table name but
-- cannot fabricate a proof.
newtype TableWithRLS = TableWithRLS
    { tableName :: Text
    } deriving (Eq, Ord)