{-# LANGUAGE MultiParamTypeClasses, TypeFamilies, FlexibleContexts, AllowAmbiguousTypes, UndecidableInstances, FlexibleInstances, DataKinds, PolyKinds, TypeApplications, ScopedTypeVariables, ConstraintKinds, TypeOperators, GADTs, GeneralizedNewtypeDeriving #-}

{-|
Module: IHP.ModelSupport.Types
Description: Core types for IHP's model and database support
Copyright: (c) digitally induced GmbH, 2020

This module contains the core types for IHP's model support.
It's designed to be lightweight and avoid heavy dependencies,
allowing modules that only need the types to compile faster.

For the full model API including query functions, use 'IHP.ModelSupport'.
-}
module IHP.ModelSupport.Types
( -- * Model Context
  ModelContext (..)
, RowLevelSecurityContext (..)
, TransactionRunner (..)
  -- * Type Families
, GetModelById
, GetTableName
, GetModelByTableName
, PrimaryKey
, GetModelName
, Include
, Include'
, NormalizeModel
  -- * Id Types
, Id'(..)
, Id
  -- * Record Metadata
, MetaBag (..)
, Violation (..)
, FieldName
  -- * Field Wrappers
, FieldWithDefault (..)
, FieldWithUpdate (..)
  -- * Utility Types
, LabeledData (..)
  -- * Exceptions
, RecordNotFoundException (..)
, EnhancedSqlError (..)
, enhancedSqlErrorMessage
, HasqlSessionError (..)
  -- * Type Classes
, CanCreate (..)
, CanUpdate (..)
, ParsePrimaryKey (..)
) where

import Prelude
import Data.ByteString (ByteString)
import Data.Text (Text)
import qualified Data.Text.Encoding
import Data.Hashable (Hashable)
import Control.DeepSeq (NFData)
import Control.Exception (Exception)
import Database.PostgreSQL.Simple.Types (Query)
import qualified Database.PostgreSQL.Simple as PG
import qualified Hasql.Pool as Hasql
import qualified Hasql.Session as HasqlSession
import qualified Hasql.Errors as HasqlErrors
import GHC.TypeLits
import GHC.Types
import Data.Data
import Data.Dynamic
import IHP.Log.Types (Logger)

-- | Runner that executes a hasql 'HasqlSession.Session' on the current transaction's connection.
--
-- The field is rank-2 polymorphic (@forall a.@), so one runner can execute
-- sessions of any result type.
newtype TransactionRunner = TransactionRunner
    { runInTransaction :: forall a. HasqlSession.Session a -> IO a
    -- ^ Runs the given session on the transaction's connection
    }

-- | Wrapper to make 'HasqlErrors.SessionError' an 'Exception', since it doesn't have one by default.
--
-- The 'Exception' instance below uses only the default methods, so values of this
-- type can be thrown and caught like any other exception.
data HasqlSessionError = HasqlSessionError HasqlErrors.SessionError
    deriving (Show)

instance Exception HasqlSessionError

-- | Provides the db connection and some IHP-specific db configuration
data ModelContext = ModelContext
    { hasqlPool :: Hasql.Pool
    -- ^ Hasql pool for prepared statement-based queries
    , transactionRunner :: Maybe TransactionRunner
    -- ^ When set, queries are sent through this runner instead of 'Hasql.Pool.use' directly
    , logger :: Logger
    -- ^ Logs all queries to this logger at log level info
    , trackTableReadCallback :: Maybe (Text -> IO ())
    -- ^ A callback that is called whenever a specific table is accessed using a SELECT query
    , rowLevelSecurity :: Maybe RowLevelSecurityContext
    -- ^ Is set to a value if row level security was enabled at runtime
    }

-- | When row level security is enabled at runtime, this keeps track of the current
-- logged in user and the postgresql role to switch to.
data RowLevelSecurityContext = RowLevelSecurityContext
    { rlsAuthenticatedRole :: Text
    -- ^ Default is @ihp_authenticated@. This value comes from the @IHP_RLS_AUTHENTICATED_ROLE@ env var.
    , rlsUserId :: Text
    -- ^ The user id of the current logged in user
    }

-- | Resolves an id type to the model type it points at.
--
-- A plain @Id' table@ resolves to @GetModelByTableName table@; an optional
-- @Maybe (Id' table)@ resolves to the optional model @Maybe (GetModelByTableName table)@.
-- The two equations have distinct head constructors ('Id'' vs 'Maybe'), so their
-- order does not affect reduction.
type family GetModelById idType :: Type where
    GetModelById (Id' table) = GetModelByTableName table
    GetModelById (Maybe (Id' table)) = Maybe (GetModelByTableName table)

-- | Maps a model type to its database table name, as a type-level 'Symbol'.
type family GetTableName record :: Symbol
-- | Maps a table name (a type-level 'Symbol') to its model type.
-- Composed with 'GetTableName' this yields 'NormalizeModel'.
type family GetModelByTableName (table :: Symbol) :: Type

-- | The Haskell type used for the primary key of a table.
--
-- Instances are usually declared by the generated Haskell code in @Generated.Types@.
--
-- __Example:__ a @users@ table whose primary key is a UUID
--
-- > type instance PrimaryKey "users" = UUID
--
-- __Example:__ a @projects@ table whose primary key is a SERIAL column
--
-- > type instance PrimaryKey "projects" = Int
type family PrimaryKey (table :: Symbol)

-- | Type-level name ('Symbol') associated with a model type. Instances are declared elsewhere.
type family GetModelName record :: Symbol

-- | @Include field record@ is the type of @record@ once the relation @field@
-- (e.g. @"author_id"@) has been included. 'NormalizeModel' maps it back to the plain model.
type family Include (field :: Symbol) record

-- | Applies 'Include' once per field name, working through the list from left to right:
--
-- > Include' '["a", "b"] record  ~  Include "b" (Include "a" record)
type family Include' (fields :: [Symbol]) record where
    Include' '[] record = record
    Include' (field ': rest) record = Include' rest (Include field record)

-- | Helper type to deal with models where relations are included or that are only
-- partially fetched: it goes through the model's table name and back, which yields
-- the plain model type.
--
-- Examples:
--
-- > NormalizeModel (Include "author_id" Post)  ~  Post
--
-- > NormalizeModel Post  ~  Post
type NormalizeModel record = GetModelByTableName (GetTableName record)

-- | A typed primary key for a row of @table@ (a type-level 'Symbol').
--
-- The wrapped value has the table's 'PrimaryKey' type. Because @table@ is part of the
-- type, ids of different tables are distinct types even when their keys share a representation.
newtype Id' table = Id (PrimaryKey table)

-- The instances are standalone-derived because each context mentions the type family
-- application @PrimaryKey table@, which a @deriving@ clause cannot infer.
deriving instance Eq (PrimaryKey table) => Eq (Id' table)
deriving instance Ord (PrimaryKey table) => Ord (Id' table)
deriving instance Hashable (PrimaryKey table) => Hashable (Id' table)
deriving instance (KnownSymbol table, NFData (PrimaryKey table)) => NFData (Id' table)
deriving instance (KnownSymbol table, Data (PrimaryKey table)) => Data (Id' table)

-- | The id type of a model, indexed by the model's table name rather than by the model itself.
--
-- Going through the table name avoids infinite recursion in model definitions.
-- Inside the definition of @Project@, a field typed
--
-- > id :: Id Project
--
-- would refer back to the type being defined, whereas
--
-- > id :: Id' "projects"
--
-- does not.
type Id record = Id' (GetTableName record)

-- | A field name, represented as a raw 'ByteString'.
type FieldName = ByteString

-- | The error message of a validator can be either a plain text value or an HTML formatted value.
--
-- Both constructors carry a 'message' field, so the 'message' selector is total.
data Violation
    = TextViolation { message :: !Text }
    -- ^ Plain text validation error, like @cannot be empty@
    | HtmlViolation { message :: !Text }
    -- ^ HTML formatted, already pre-escaped validation error, for example:
    --
    -- > Invalid, please <a href="http://example.com">check the documentation</a>
    deriving (Eq, Show)

-- | Every IHP database record has a magic @meta@ field which keeps a @MetaBag@ inside.
-- This data structure is used e.g. to keep track of the validation errors that happened.
data MetaBag = MetaBag
    { annotations :: ![(Text, Violation)]
    -- ^ Stores validation failures, as a list of (field name, error) pairs. E.g. @annotations = [ ("name", TextViolation "cannot be empty") ]@
    , touchedFields :: ![Text]
    -- ^ Whenever a 'set' is called on a field, it will be marked as touched. Only touched fields are saved to the database when you call 'updateRecord'
    , originalDatabaseRecord :: Maybe Dynamic
    -- ^ When the record has been fetched from the database, we save the initial database record here. This is used by 'didChange' to check if a field value is different from the initial database value.
    } deriving (Show)

-- | Two meta bags are equal when their 'annotations' and their 'touchedFields' are equal.
--
-- 'originalDatabaseRecord' is not compared: it holds a 'Dynamic', which has no 'Eq'
-- instance. That field is why this instance is written by hand instead of derived.
instance Eq MetaBag where
    MetaBag { annotations, touchedFields } == MetaBag { annotations = annotations', touchedFields = touchedFields' } =
        annotations == annotations' && touchedFields == touchedFields'

-- | Represents fields that have a default value in an SQL schema
--
--   The 'Default' constructor represents the default value from the schema,
--   while the 'NonDefault' constructor holds some other value for the field
data FieldWithDefault valueType
    = Default
    -- ^ Use the default value declared in the schema
    | NonDefault valueType
    -- ^ Use this explicit value instead of the schema default
    deriving (Eq, Show)

-- | Represents fields that may have been updated
--
--   The 'NoUpdate' constructor represents the existing value in the database,
--   while the 'Update' constructor holds some new value for the field
data FieldWithUpdate name value
    = NoUpdate (Proxy name)
    -- ^ Keep the existing database value; the 'Proxy' carries the type-level @name@
    | Update value
    -- ^ Replace the field with this new value
    deriving (Eq, Show)

-- | Record type for objects of model types labeled with values from different database tables.
-- (e.g. comments labeled with the IDs of the posts they belong to).
data LabeledData a b = LabeledData
    { labelValue :: a
    -- ^ The label, e.g. the id of the post a comment belongs to
    , contentValue :: b
    -- ^ The labeled object
    } deriving (Show)

-- | Thrown by 'fetchOne' when the query result is empty
data RecordNotFoundException
    = RecordNotFoundException
    { queryAndParams :: Text
    -- ^ The query and its parameters, rendered as text
    }
    deriving (Show)

instance Exception RecordNotFoundException

-- | Whenever calls to 'Database.PostgreSQL.Simple.query' or 'Database.PostgreSQL.Simple.execute'
-- raise an 'Database.PostgreSQL.Simple.SqlError' exception, we wrap that exception in this data structure.
-- This allows us to show the actual database query that has triggered the error.
data EnhancedSqlError
    = EnhancedSqlError
    { sqlErrorQuery :: Query
    -- ^ The query that triggered the error
    , sqlErrorQueryParams :: Text
    -- ^ The query parameters, rendered as text
    , sqlError :: PG.SqlError
    -- ^ The original error raised by postgresql-simple
    } deriving (Show)

instance Exception EnhancedSqlError

-- | Extract the SQL error message as Text from an EnhancedSqlError.
--
-- This avoids downstream packages needing to import postgresql-simple
-- to access the 'PG.sqlErrorMsg' field on 'PG.SqlError'.
--
-- The raw message bytes are decoded as UTF-8 leniently: an invalid byte sequence
-- becomes U+FFFD (the replacement character). Nothing is thrown, so turning an
-- error into text cannot itself fail with a 'Data.Text.Encoding.Error.UnicodeException'.
enhancedSqlErrorMessage :: EnhancedSqlError -> Text
enhancedSqlErrorMessage e =
    Data.Text.Encoding.decodeUtf8With replaceInvalid e.sqlError.sqlErrorMsg
  where
    -- Equivalent to 'Data.Text.Encoding.Error.lenientDecode', inlined to avoid an extra import
    replaceInvalid _ _ = Just '\xFFFD'
{-# INLINE enhancedSqlErrorMessage #-}

-- | Records that can be created through the current 'ModelContext'.
class CanCreate a where
    -- | Creates the given record and returns the created record.
    create :: (?modelContext :: ModelContext) => a -> IO a

    -- | Like 'create', but for a list of records; returns the created records.
    createMany :: (?modelContext :: ModelContext) => [a] -> IO [a]

    -- | Like 'create' but doesn't return the created record.
    --
    -- The default implementation runs 'create' and discards its result.
    createRecordDiscardResult :: (?modelContext :: ModelContext) => a -> IO ()
    createRecordDiscardResult record = do
        _ <- create record
        pure ()

-- | Records that can be updated through the current 'ModelContext'.
class CanUpdate a where
    -- | Updates the given record and returns the updated record.
    -- Only fields marked as touched in the record's 'MetaBag' are saved to the database.
    updateRecord :: (?modelContext :: ModelContext) => a -> IO a

    -- | Like 'updateRecord' but doesn't return the updated record.
    --
    -- The default implementation runs 'updateRecord' and discards its result.
    updateRecordDiscardResult :: (?modelContext :: ModelContext) => a -> IO ()
    updateRecordDiscardResult record = do
        _ <- updateRecord record
        pure ()

-- | Primary key types that can be parsed from their 'Text' representation.
class ParsePrimaryKey key where
    -- | Parses a key from text. 'Nothing' presumably signals input that is not a
    -- valid key; the exact rules are up to each instance.
    parsePrimaryKey :: Text -> Maybe key