module IHP.DataSync.RowLevelSecurity
( ensureRLSEnabled
, hasRLSEnabled
, TableWithRLS (tableName)
, makeCachedEnsureRLSEnabled
, sqlQueryWithRLS
, sqlExecWithRLS
)
where
import IHP.ControllerPrelude
import qualified Database.PostgreSQL.Simple as PG
import qualified Database.PostgreSQL.Simple.ToField as PG
import qualified Database.PostgreSQL.Simple.Types as PG
import qualified Database.PostgreSQL.Simple.ToRow as PG
import qualified IHP.DataSync.Role as Role
import qualified Data.Set as Set
-- | Runs a query with row level security applied and returns the result rows.
--
-- The statement is wrapped by 'wrapStatementWithRLS' before it is sent, so it
-- runs after @SET LOCAL ROLE@ (the authenticated role) and
-- @SET LOCAL rls.ihp_user_id@ (the current user's id, or an empty string when
-- nobody is logged in). The caller's @?@ placeholders and @parameters@ are used
-- unchanged.
sqlQueryWithRLS ::
    ( ?modelContext :: ModelContext
    , PG.ToRow parameters
    , ?context :: ControllerContext
    , userId ~ Id CurrentUserRecord
    , Show (PrimaryKey (GetTableName CurrentUserRecord))
    , HasNewSessionUrl CurrentUserRecord
    , Typeable CurrentUserRecord
    , HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
    , PG.ToField userId
    , FromRow result
    ) => PG.Query -> parameters -> IO [result]
sqlQueryWithRLS query parameters = sqlQuery queryWithRLS parametersWithRLS
    where
        (queryWithRLS, parametersWithRLS) = wrapStatementWithRLS query parameters
{-# INLINE sqlQueryWithRLS #-}
-- | Executes a statement with row level security applied and returns the
-- count reported by 'sqlExec'.
--
-- The statement is wrapped by 'wrapStatementWithRLS' before it is sent, so it
-- runs after @SET LOCAL ROLE@ (the authenticated role) and
-- @SET LOCAL rls.ihp_user_id@ (the current user's id, or an empty string when
-- nobody is logged in). The caller's @?@ placeholders and @parameters@ are used
-- unchanged.
sqlExecWithRLS ::
    ( ?modelContext :: ModelContext
    , PG.ToRow parameters
    , ?context :: ControllerContext
    , userId ~ Id CurrentUserRecord
    , Show (PrimaryKey (GetTableName CurrentUserRecord))
    , HasNewSessionUrl CurrentUserRecord
    , Typeable CurrentUserRecord
    , HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
    , PG.ToField userId
    ) => PG.Query -> parameters -> IO Int64
sqlExecWithRLS query parameters = sqlExec queryWithRLS parametersWithRLS
    where
        (queryWithRLS, parametersWithRLS) = wrapStatementWithRLS query parameters
{-# INLINE sqlExecWithRLS #-}
-- | Prefixes a statement with the session settings that row level security
-- policies read, and prepends the matching parameters.
--
-- The resulting query text is
--
-- > SET LOCAL ROLE ?; SET LOCAL rls.ihp_user_id = ?; <query>;
--
-- The resulting parameter list is:
--
-- 1. the authenticated role, as an identifier;
-- 2. the current user id;
-- 3. the caller's own parameters.
--
-- The two @SET LOCAL@ placeholders come first, so the caller's @?@
-- placeholders still line up with @'PG.toRow' parameters@.
wrapStatementWithRLS ::
    ( ?modelContext :: ModelContext
    , PG.ToRow parameters
    , ?context :: ControllerContext
    , userId ~ Id CurrentUserRecord
    , Show (PrimaryKey (GetTableName CurrentUserRecord))
    , HasNewSessionUrl CurrentUserRecord
    , Typeable CurrentUserRecord
    , HasField "id" CurrentUserRecord (Id' (GetTableName CurrentUserRecord))
    , PG.ToField userId
    ) => PG.Query -> parameters -> (PG.Query, [PG.Action])
wrapStatementWithRLS query parameters = (queryWithRLS, parametersWithRLS)
    where
        queryWithRLS :: PG.Query
        queryWithRLS = "SET LOCAL ROLE ?; SET LOCAL rls.ihp_user_id = ?; " <> query <> ";"

        maybeUserId :: Maybe (Id CurrentUserRecord)
        maybeUserId = (.id) <$> currentUserOrNothing

        -- When no user is logged in, an empty string is sent as the user id
        -- instead of a SQL NULL.
        -- NOTE(review): presumably because SET LOCAL only takes literal
        -- values; RLS policies must treat '' as "no user" — confirm.
        encodedUserId :: PG.Action
        encodedUserId = case maybeUserId of
            Just userId -> PG.toField userId
            Nothing -> PG.toField ("" :: Text)

        -- encodedUserId is already an Action, so it is used directly rather
        -- than being wrapped again with the identity 'PG.toField'.
        parametersWithRLS :: [PG.Action]
        parametersWithRLS =
            [ PG.toField (PG.Identifier Role.authenticatedRole)
            , encodedUserId
            ]
            <> PG.toRow parameters
{-# INLINE wrapStatementWithRLS #-}
-- | Checks that row level security is enabled for @table@ and, if so, returns
-- a 'TableWithRLS' value recording that fact.
--
-- Calls 'error' with "Row level security is required for accessing this table"
-- when 'hasRLSEnabled' reports that RLS is off.
ensureRLSEnabled :: (?modelContext :: ModelContext) => Text -> IO TableWithRLS
ensureRLSEnabled table =
    hasRLSEnabled table >>= \enabled ->
        if enabled
            then pure (TableWithRLS table)
            else error "Row level security is required for accessing this table"
-- | Builds a variant of 'ensureRLSEnabled' that remembers tables which have
-- already passed the check.
--
-- Remembered tables get a 'TableWithRLS' without another database query. Only
-- successful checks are recorded: if 'ensureRLSEnabled' throws, the table is
-- not added, so the next call checks again.
--
-- The cache is read and updated in two separate steps. Concurrent callers may
-- therefore both run the check for the same table, which only costs an extra
-- query.
makeCachedEnsureRLSEnabled :: (?modelContext :: ModelContext) => IO (Text -> IO TableWithRLS)
makeCachedEnsureRLSEnabled = do
    verifiedTables <- newIORef Set.empty
    let cachedEnsure table = do
            alreadyVerified <- Set.member table <$> readIORef verifiedTables
            if alreadyVerified
                then pure (TableWithRLS table)
                else do
                    proof <- ensureRLSEnabled table
                    modifyIORef' verifiedTables (Set.insert table)
                    pure proof
    pure cachedEnsure
-- | Reports whether row level security is enabled for a table, by reading
-- @pg_class.relrowsecurity@.
--
-- The name is passed through @quote_ident@ and then cast to @regclass@.
-- NOTE(review): quote_ident treats the whole argument as a single identifier,
-- so a schema-qualified name like "public.users" presumably won't resolve.
-- A table that doesn't exist is expected to make the regclass cast fail —
-- confirm.
hasRLSEnabled :: (?modelContext :: ModelContext) => Text -> IO Bool
hasRLSEnabled targetTable = sqlQueryScalar rlsFlagQuery [targetTable]
    where
        rlsFlagQuery :: PG.Query
        rlsFlagQuery = "SELECT relrowsecurity FROM pg_class WHERE oid = quote_ident(?)::regclass"
-- | Evidence that row level security was found enabled for a table.
--
-- The module exports the type and its 'tableName' field, but not the
-- 'TableWithRLS' constructor. Inside this module, values are built by
-- 'ensureRLSEnabled' and 'makeCachedEnsureRLSEnabled'.
--
-- NOTE(review): because the field is exported, record-update syntax
-- (@proof { tableName = ... }@) can still produce a value for a different
-- table; consider exporting a read-only accessor instead.
newtype TableWithRLS = TableWithRLS
    { tableName :: Text
    }
    deriving (Ord, Eq)