{-# LANGUAGE BangPatterns, TypeFamilies, DataKinds, PolyKinds, TypeApplications, ScopedTypeVariables, ConstraintKinds, TypeOperators, GADTs, UndecidableInstances, StandaloneDeriving, FunctionalDependencies, FlexibleContexts, InstanceSigs, AllowAmbiguousTypes, DeriveAnyClass #-}

{-|
Module: IHP.QueryBuilder.HasqlCompiler
Description: Compile QueryBuilder to Hasql Statement
Copyright: (c) digitally induced GmbH, 2025

This module compiles QueryBuilder queries directly to Hasql 'Statement' values
by threading a parameter counter and encoder accumulator through compilation.
-}
module IHP.QueryBuilder.HasqlCompiler
( buildStatement
, buildWrappedStatement
, toSQL
, compileOperator
, CompilerState(..)
, emptyCompilerState
, nextParam
) where

import IHP.Prelude
import qualified Hasql.Encoders as Encoders
import qualified Hasql.Decoders as Decoders
import qualified Hasql.Statement as Hasql
import Data.Functor.Contravariant (contramap)
import Data.Functor.Contravariant.Divisible (conquer)
import IHP.QueryBuilder.Types
import IHP.QueryBuilder.Compiler (buildQuery)
import qualified Data.List as List

-- | State threaded through query compilation.
--
-- Pairs the number of the next positional placeholder with the parameter
-- encoders collected so far (in placeholder order).
data CompilerState = CompilerState
    !Int                  -- ^ Number to use for the next @$N@ placeholder.
    !(Encoders.Params ()) -- ^ Encoders for every placeholder emitted so far.

-- | The state compilation starts from: the first placeholder handed out is
-- @$1@ and no parameter encoders have been collected yet.
emptyCompilerState :: CompilerState
emptyCompilerState = CompilerState firstPlaceholder conquer
    where
        -- PostgreSQL positional parameters are numbered from 1.
        firstPlaceholder = 1
{-# INLINE emptyCompilerState #-}

-- | Allocate the next positional placeholder for a parameter.
--
-- Returns the placeholder text (e.g. @$3@) and a new state. In the new state
-- the counter has advanced by one, and the given encoder is appended after
-- the ones already collected. Encoders therefore stay in the same order as
-- their placeholder numbers.
nextParam :: Encoders.Params () -> CompilerState -> (Text, CompilerState)
nextParam paramEncoder (CompilerState counter encoders) =
    let placeholder = "$" <> tshow counter
        advanced = CompilerState (counter + 1) (encoders <> paramEncoder)
    in (placeholder, advanced)
{-# INLINE nextParam #-}

-- | Turn an 'SQLQuery' into a Hasql 'Statement'.
--
-- The query is compiled starting from 'emptyCompilerState'. The resulting SQL
-- text and the accumulated parameter encoder are passed to 'Hasql.preparable'
-- together with the caller's result decoder.
buildStatement :: SQLQuery -> Decoders.Result a -> Hasql.Statement () a
buildStatement sqlQuery decoder = Hasql.preparable sql encoder decoder
    where
        (sql, CompilerState _ encoder) = compileQuery emptyCompilerState sqlQuery

-- | Variant of 'buildStatement' that surrounds the compiled SQL with the given
-- prefix and suffix text, e.g. for @SELECT COUNT(*) FROM (inner) AS alias@.
--
-- Prefix and suffix are spliced in verbatim and must not contain @$N@
-- placeholders of their own; only the inner query's parameters are encoded.
buildWrappedStatement :: Text -> SQLQuery -> Text -> Decoders.Result a -> Hasql.Statement () a
buildWrappedStatement prefix sqlQuery suffix decoder =
    Hasql.preparable wrappedSql encoder decoder
    where
        (innerSql, CompilerState _ encoder) = compileQuery emptyCompilerState sqlQuery
        wrappedSql = mconcat [prefix, innerSql, suffix]

-- | Render a 'QueryBuilder' as SQL text only, e.g. for tests and error
-- messages. The parameter encoders produced during compilation are dropped,
-- so the text contains bare @$N@ placeholders.
toSQL :: forall table. KnownSymbol table => QueryBuilder table -> Text
toSQL = fst . compileQuery emptyCompilerState . buildQuery

-- | Compile a full 'SQLQuery' to SQL text, threading the compile state.
--
-- Clauses are emitted in SQL order:
--
-- @SELECT [DISTINCT | DISTINCT ON (...)] columns FROM table
--   [WHERE ...] [ORDER BY ...] [LIMIT $n] [OFFSET $n]@
--
-- Absent clauses contribute nothing to the text. WHERE, LIMIT and OFFSET may
-- allocate placeholders, so the state is threaded through them in exactly
-- that order. This keeps placeholder numbers in the same order as they appear
-- in the SQL.
compileQuery :: CompilerState -> SQLQuery -> (Text, CompilerState)
compileQuery cc0 SQLQuery { selectFrom, distinctClause, distinctOnClause, whereCondition, orderByClause, limitClause, offsetClause, columnsSql } =
    let
        columnsAndFrom :: Text
        columnsAndFrom = columnsSql <> " FROM " <> selectFrom

        -- PostgreSQL's grammar accepts either DISTINCT or DISTINCT ON (...),
        -- never both: "SELECT DISTINCT DISTINCT ON (...)" is a syntax error.
        -- When both are requested, DISTINCT ON wins. Note that this drops the
        -- plain DISTINCT; full-row de-duplication on top of DISTINCT ON would
        -- need an outer query.
        selectPart :: Text
        selectPart = case distinctOnClause of
            Just col -> "SELECT DISTINCT ON (" <> col <> ") " <> columnsAndFrom
            Nothing
                | distinctClause -> "SELECT DISTINCT " <> columnsAndFrom
                | otherwise -> "SELECT " <> columnsAndFrom

        -- WHERE: only appended when there is a condition; may allocate params.
        withWhere :: (Text, CompilerState)
        withWhere = case whereCondition of
            Nothing -> (selectPart, cc0)
            Just condition ->
                let (condText, cc') = compileCondition cc0 condition
                in (selectPart <> " WHERE " <> condText, cc')

        -- ORDER BY: only appended when there are clauses; allocates no params.
        withOrderBy :: (Text, CompilerState)
        withOrderBy = case (orderByClause, withWhere) of
            ([], acc) -> acc
            (clauses, (sqlSoFar, cc)) ->
                (sqlSoFar <> " ORDER BY " <> compileOrderByClauses clauses, cc)

        -- Append "<keyword>$n" for an optional LIMIT/OFFSET count.
        appendCount :: Text -> Maybe Int -> (Text, CompilerState) -> (Text, CompilerState)
        appendCount _ Nothing acc = acc
        appendCount keyword (Just n) (sqlSoFar, cc) =
            let (placeholder, cc') = nextParam (countEncoder n) cc
            in (sqlSoFar <> keyword <> placeholder, cc')

        -- PostgreSQL coerces LIMIT/OFFSET to bigint, so the count is bound as
        -- int8. Narrowing to Int32 (int4) would silently wrap Int values
        -- above 2^31 - 1.
        countEncoder :: Int -> Encoders.Params ()
        countEncoder n =
            contramap (const (fromIntegral n)) (Encoders.param (Encoders.nonNullable Encoders.int8))
    in appendCount " OFFSET " offsetClause (appendCount " LIMIT " limitClause withOrderBy)

-- | Compile a 'Condition' tree to SQL text, threading the compile state.
--
-- A column condition renders as @column op value@. @applyLeft@ / @applyRight@
-- wrap the column / value in a function call such as @LOWER(...)@. OR / AND
-- operands are parenthesised and compiled left to right, so placeholder
-- numbers follow their textual order.
compileCondition :: CompilerState -> Condition -> (Text, CompilerState)
compileCondition cc (ColumnCondition column operator value applyLeft applyRight) =
    case operator of
        -- IS [NOT] NULL takes the bare NULL keyword. Wrapping it (for example
        -- "IS LOWER(NULL)") is a syntax error, so applyRight is ignored here.
        -- No parameter is consumed.
        IsOp -> (binary "NULL", cc)
        IsNotOp -> (binary "NULL", cc)
        -- Raw SQL: the value is the whole right-hand side; no operator.
        SqlOp -> (colText <> " " <> valText, cc')
        InOp -> (binary ("(" <> valText <> ")"), cc')
        NotInOp -> (binary ("(" <> valText <> ")"), cc')
        _ -> (binary (applyFn applyRight valText), cc')
    where
        applyFn :: Maybe Text -> Text -> Text
        applyFn (Just fn) txt = fn <> "(" <> txt <> ")"
        applyFn Nothing txt = txt

        colText :: Text
        colText = applyFn applyLeft column

        -- Render "column op rhs".
        binary :: Text -> Text
        binary rhs = colText <> " " <> compileOperator operator <> " " <> rhs

        -- Lazily bound: only the branches that emit the value use it, so the
        -- IS [NOT] NULL branches never allocate a placeholder.
        (valText, cc') = compileConditionValue cc value
compileCondition cc (OrCondition a b) =
    let (aText, cc1) = compileCondition cc a
        (bText, cc2) = compileCondition cc1 b
    in ("(" <> aText <> ") OR (" <> bText <> ")", cc2)
compileCondition cc (AndCondition a b) =
    let (aText, cc1) = compileCondition cc a
        (bText, cc2) = compileCondition cc1 b
    in ("(" <> aText <> ") AND (" <> bText <> ")", cc2)

-- | Compile the right-hand side of a column condition.
--
-- A 'Param' is bound to a fresh @$N@ placeholder, and its encoder is appended
-- to the state. A 'Literal' is spliced into the SQL verbatim and leaves the
-- state untouched.
compileConditionValue :: CompilerState -> ConditionValue -> (Text, CompilerState)
compileConditionValue cc conditionValue = case conditionValue of
    Param encoder -> nextParam encoder cc
    Literal sqlText -> (sqlText, cc)

-- | Render ORDER BY items as a comma-separated list, without the @ORDER BY@
-- keyword itself. Only 'Desc' adds a direction suffix. Any other direction
-- renders as the bare column, which PostgreSQL treats as ascending.
compileOrderByClauses :: [OrderByClause] -> Text
compileOrderByClauses clauses =
    mconcat (List.intersperse "," [ renderClause clause | clause <- clauses ])
    where
        renderClause :: OrderByClause -> Text
        renderClause OrderByClause { orderByColumn, orderByDirection }
            | orderByDirection == Desc = orderByColumn <> " DESC"
            | otherwise = orderByColumn

-- | Render a 'FilterOperator' as the SQL operator placed between a column and
-- its value. Case-insensitive variants map to @ILIKE@ / @~*@. 'SqlOp' renders
-- as the empty string because it has no operator of its own.
compileOperator :: FilterOperator -> Text
compileOperator operator = case operator of
    EqOp -> "="
    NotEqOp -> "!="
    InOp -> "= ANY"
    NotInOp -> "<> ALL"
    IsOp -> "IS"
    IsNotOp -> "IS NOT"
    LikeOp sensitivity -> bySensitivity "LIKE" "ILIKE" sensitivity
    NotLikeOp sensitivity -> bySensitivity "NOT LIKE" "NOT ILIKE" sensitivity
    MatchesOp sensitivity -> bySensitivity "~" "~*" sensitivity
    GreaterThanOp -> ">"
    GreaterThanOrEqualToOp -> ">="
    LessThanOp -> "<"
    LessThanOrEqualToOp -> "<="
    SqlOp -> ""
    where
        -- Pick the case-sensitive or case-insensitive spelling of an operator.
        bySensitivity :: Text -> Text -> MatchSensitivity -> Text
        bySensitivity sensitive insensitive sensitivity = case sensitivity of
            CaseSensitive -> sensitive
            CaseInsensitive -> insensitive