{-|
Module: IHP.WebSocket
Description: Building blocks for websocket applications
Copyright: (c) digitally induced GmbH, 2020
-}
{-# LANGUAGE AllowAmbiguousTypes #-}
module IHP.WebSocket
( WSApp (..)
, startWSApp
, setState
, getState
, receiveData
, receiveDataMessage
, sendTextData
, sendJSON
)
where

import IHP.Prelude
import qualified Network.WebSockets as Websocket
import IHP.ApplicationContext
import IHP.Controller.RequestContext
import qualified Data.UUID as UUID
import qualified Data.Maybe as Maybe
import qualified Control.Exception as Exception
import IHP.Controller.Context
import qualified Data.Aeson as Aeson

import qualified IHP.Log.Types as Log
import qualified IHP.Log as Log

import Control.Concurrent.Chan
import Control.Concurrent
import System.Timeout
import Data.Function (fix)
import qualified Network.WebSockets.Connection as WebSocket

-- | A websocket application. The @state@ type identifies the application and is also
-- the type of its per-connection state.
--
-- Start one with @startWSApp \@YourState connection@. Inside the callbacks the current
-- state is reachable through the implicit @?state@ 'IORef' (see 'getState' and
-- 'setState') and the client connection through @?connection@.
class WSApp state where
    -- | Value the per-connection state 'IORef' is initialized with by 'startWSApp'.
    initialState :: state

    -- | Main handler of a connection. The default implementation does nothing.
    run
        :: ( ?state :: IORef state
           , ?context :: ControllerContext
           , ?applicationContext :: ApplicationContext
           , ?modelContext :: ModelContext
           , ?connection :: Websocket.Connection
           )
        => IO ()
    run = return ()

    -- | Ping callback. The default implementation does nothing.
    --
    -- NOTE(review): 'startWSApp' in this module never invokes 'onPing'; confirm whether
    -- it is called anywhere else.
    onPing
        :: ( ?state :: IORef state
           , ?context :: ControllerContext
           , ?applicationContext :: ApplicationContext
           , ?modelContext :: ModelContext
           , ?connection :: Websocket.Connection
           )
        => IO ()
    onPing = return ()

    -- | Cleanup callback. 'startWSApp' runs it after the connection handling has
    -- finished, also when that handling ended with an exception. The default
    -- implementation does nothing.
    onClose
        :: ( ?state :: IORef state
           , ?context :: ControllerContext
           , ?applicationContext :: ApplicationContext
           , ?modelContext :: ModelContext
           , ?connection :: Websocket.Connection
           )
        => IO ()
    onClose = return ()

-- | Runs the 'WSApp' selected by the @state@ type argument on an accepted websocket
-- connection, e.g. @startWSApp \@MyAppState connection@.
--
-- A fresh state 'IORef' is created from 'initialState' and bound as @?state@, and the
-- connection is bound as @?connection@. 'run' is then executed next to a pinger thread
-- (see 'withPinger'). 'onClose' always runs afterwards, even when 'run' throws.
--
-- Exception handling:
--
-- * 'Websocket.ConnectionClosed' and 'Websocket.CloseRequest' count as a normal shutdown.
-- * Any other 'Websocket.ConnectionException' is turned into an 'error'.
-- * Asynchronous exceptions (e.g. the thread being killed) are rethrown so the thread
--   really stops instead of silently continuing.
-- * All remaining exceptions are logged with 'Log.error' and swallowed.
startWSApp :: forall state. (WSApp state, ?applicationContext :: ApplicationContext, ?requestContext :: RequestContext, ?context :: ControllerContext, ?modelContext :: ModelContext) => Websocket.Connection -> IO ()
startWSApp connection = do
    state <- newIORef (initialState @state)
    let ?state = state
    let ?connection = connection

    -- Runs the app with a @?connection@ whose pong hook forwards every pong into the
    -- pinger's channel. Only 'run' sees this modified connection.
    -- NOTE(review): pongs are presumably only observed while 'run' is reading from the
    -- connection; confirm against the websockets package.
    let runWithPongChan pongChan = do
            let connectionOnPong = writeChan pongChan ()
            let ?connection = connection
                    { WebSocket.connectionOptions = (get #connectionOptions connection) { WebSocket.connectionOnPong }
                    }
                in
                    run @state

    result <- Exception.try ((withPinger connection runWithPongChan) `Exception.finally` onClose @state)
    case result of
        Left (e@Exception.SomeException{}) ->
            case Exception.fromException e of
                (Just Websocket.ConnectionClosed) -> pure ()
                (Just (Websocket.CloseRequest {})) -> pure ()
                (Just other) -> error ("Unhandled Websocket exception: " <> show other)
                Nothing
                    -- Never swallow asynchronous exceptions (ThreadKilled, timeouts, ...):
                    -- rethrow them unchanged so the thread actually terminates.
                    | Maybe.isJust (Exception.fromException e :: Maybe Exception.SomeAsyncException) -> Exception.throwIO e
                    | otherwise -> Log.error (tshow e)
        Right _ -> pure ()

-- | Overwrites the current connection state, i.e. the contents of @?state@, with the
-- given value.
setState :: (?state :: IORef state) => state -> IO ()
setState newValue = ?state `writeIORef` newValue

-- | Reads the current connection state, i.e. the contents of @?state@.
getState :: (?state :: IORef state) => IO state
getState =
    let stateRef = ?state
    in readIORef stateRef

-- | Receives the next data message on @?connection@ and decodes it into @a@ through
-- its 'Websocket.WebSocketsData' instance.
--
-- Typically used with a type application, e.g. @receiveData \@LByteString@.
receiveData :: (?connection :: Websocket.Connection, Websocket.WebSocketsData a) => IO a
receiveData =
    let connection = ?connection
    in Websocket.receiveData connection

-- | Receives the next raw 'Websocket.DataMessage' from @?connection@ without decoding it.
receiveDataMessage :: (?connection :: Websocket.Connection) => IO Websocket.DataMessage
receiveDataMessage = Websocket.receiveDataMessage connection
    where
        connection = ?connection

-- | Sends a value to the client over @?connection@ using 'Websocket.sendTextData'.
sendTextData :: (?connection :: Websocket.Connection, Websocket.WebSocketsData text) => text -> IO ()
sendTextData = Websocket.sendTextData ?connection

-- | JSON-encodes a payload with 'Aeson.encode' and sends it over the websocket using
-- 'sendTextData'.
--
-- __Example:__
--
-- > message <- Aeson.decode <$> receiveData @LByteString
-- >
-- > case message of
-- >     Just decodedMessage -> handleMessage decodedMessage
-- >     Nothing -> sendJSON FailedToDecodeMessageError
--
sendJSON :: (?connection :: Websocket.Connection, Aeson.ToJSON value) => value -> IO ()
sendJSON = sendTextData . Aeson.encode

-- | Allows sending and receiving 'UUID' values directly, e.g. @receiveData \@UUID@.
--
-- Incoming text and binary messages are parsed with 'UUID.fromLazyASCIIBytes'. The
-- payload comes from the client, so a malformed value is reported with an explicit error
-- that names the problem and shows the received bytes. Previously this was an opaque
-- 'Maybe.fromJust' failure.
--
-- NOTE(review): this is an orphan instance (neither the class nor 'UUID' is defined here).
instance Websocket.WebSocketsData UUID where
    fromDataMessage (Websocket.Text byteString _) = Websocket.fromLazyByteString byteString
    fromDataMessage (Websocket.Binary byteString) = Websocket.fromLazyByteString byteString

    fromLazyByteString byteString =
        case UUID.fromLazyASCIIBytes byteString of
            Just uuid -> uuid
            Nothing -> error ("IHP.WebSocket: Expected a UUID in the websocket message, but got: " <> tshow byteString)

    toLazyByteString = UUID.toLazyASCIIBytes

-- | Raised in the application thread by 'withPinger' (via 'cancelWith') when the
-- pinger gave up waiting for a pong from the client.
data PongTimeout = PongTimeout
    deriving Show

instance Exception PongTimeout

-- | Pause between sending a ping and checking for the matching pong, in seconds.
--
-- Note that 'threadDelay' expects microseconds, so this value has to be multiplied by
-- 1000000 wherever it is passed to it.
pingWaitTime :: Int
pingWaitTime = 30


-- | Runs @action@ in its own thread next to a pinger thread ('runPinger') and waits for
-- whichever of the two finishes first.
--
-- We cannot use the @withPingThread@ of the websockets package, as it doesn't deal with
-- pong messages, so open connections whose client silently went away would stay around
-- forever.
--
-- @action@ receives the channel the pinger reads pongs from and should write @()@ into it
-- for every pong received (see 'startWSApp'). The pinger sends a ping, waits
-- 'pingWaitTime', then gives the client up to one more second to deliver a pong.
--
-- Outcomes:
--
-- * @action@ finishes first: the pinger is cancelled and any exception thrown by
--   @action@ is rethrown.
-- * The pinger throws: @action@ is cancelled and the exception is rethrown.
-- * The pinger returns (no pong arrived in time): @action@ is cancelled with
--   'PongTimeout' and 'withPinger' itself returns normally.
--
-- If the calling thread is interrupted while waiting, both child threads are cancelled
-- instead of being leaked.
--
-- This implementation is based on https://github.com/jaspervdj/websockets/issues/159#issuecomment-552776502
withPinger :: Websocket.Connection -> (Chan () -> IO a) -> IO ()
withPinger conn action = do
    pongChan <- newChan
    mainAsync <- async (action pongChan)
    pingerAsync <- async (runPinger conn pongChan)

    outcome <- waitEitherCatch mainAsync pingerAsync
        `Exception.onException` (cancel pingerAsync >> cancel mainAsync)

    case outcome of
        -- The application thread finished (normally or not): stop the pinger
        Left appResult -> do
            cancel pingerAsync
            case appResult of
                Left exception -> Exception.throwIO exception
                Right _ -> pure ()
        -- The pinger thread should never throw an exception. If it does, stop the app thread
        Right (Left exception) -> do
            cancel mainAsync
            Exception.throwIO exception
        -- The pinger thread exited due to a pong timeout. Tell the app thread about it.
        Right (Right ()) -> cancelWith mainAsync PongTimeout

-- | Ping loop used by 'withPinger'.
--
-- Each round sends an empty ping, sleeps for 'pingWaitTime' seconds, and then waits up to
-- one more second for a pong to appear on @pongChan@. If a pong arrives, the next round
-- starts. Otherwise the loop returns @()@, which the caller treats as a pong timeout.
runPinger :: Websocket.Connection -> Chan () -> IO ()
runPinger conn pongChan = fix $ \loop -> do
    Websocket.sendPing conn (mempty :: ByteString)
    -- 'threadDelay' takes microseconds while 'pingWaitTime' is in seconds. Passing it
    -- unscaled slept only 30µs, so a new ping went out right after every pong.
    threadDelay (pingWaitTime * 1000000)
    -- See if we got a pong in that time (allowing one extra second of grace)
    timeout 1000000 (readChan pongChan) >>= \case
        Just () -> loop
        Nothing -> pure ()