{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE RankNTypes #-}

-- | Turn a 'Get' into a 'Sink' and a 'Put' into a 'Source'.
-- These functions are built upon the "Data.Conduit.Cereal.Internal" functions, with default
-- implementations of 'ErrorHandler' and 'TerminationHandler'.
--
-- The default 'ErrorHandler' and 'TerminationHandler' both throw a 'GetException'.

module Data.Conduit.Cereal ( GetException
                           , sinkGet
                           , conduitGet
                           , conduitGet2
                           , sourcePut
                           , conduitPut
                           ) where

import           Control.Exception.Base
import           Control.Monad.Trans.Resource (MonadThrow, throwM)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import           Data.Conduit (ConduitT, leftover, await, yield)
import qualified Data.Conduit.List as CL
import           Data.Serialize hiding (get, put)
import           Data.Typeable

import           Data.Conduit.Cereal.Internal

-- | Exception thrown by the default error and termination handlers in this
-- module when a 'Get' cannot produce a value.
--
-- The 'String' is the failure message reported by cereal, or an
-- internal-error message when cereal unexpectedly asks for more input after
-- being told the stream has ended.
data GetException = GetException String
  deriving (Show, Typeable)

-- | Allows 'GetException' to be thrown with 'throwM' and caught as an
-- ordinary exception.
instance Exception GetException

-- | Run a 'Get' repeatedly on the input stream, producing an output stream of
-- whatever the 'Get' outputs.
--
-- A parse failure is reported by throwing a 'GetException' via 'throwM'.
--
-- Deprecated because empty input chunks are handed straight to cereal, which
-- interprets them as end of input; 'conduitGet2' filters them out instead.
conduitGet :: MonadThrow m => Get o -> ConduitT BS.ByteString o m ()
conduitGet = mkConduitGet errorHandler
  where
    -- Turn cereal's failure message into a thrown 'GetException'.
    errorHandler msg = throwM (GetException msg)
{-# DEPRECATED conduitGet "Please switch to conduitGet2, see comment on that function" #-}

-- | Convert a 'Get' into a 'Sink'. The 'Get' will be streamed bytes until it
-- returns 'Done' or 'Fail'.
--
-- If the 'Get' succeeds, the decoded value is returned and any unconsumed
-- part of the input is pushed back onto the stream as a leftover.
-- If the 'Get' fails, either because of a deserialization error or because
-- the input stream ended early, a 'GetException' is thrown.
sinkGet :: MonadThrow m => Get r -> ConduitT BS.ByteString o m r
sinkGet = mkSinkGet errorHandler terminationHandler
  where
    -- Turn cereal's failure message into a thrown 'GetException'.
    errorHandler msg = throwM (GetException msg)

    -- Termination handler for 'mkSinkGet' (presumably run once upstream has no
    -- more data). Feeding 'BS.empty' to the pending continuation tells cereal
    -- that the input has ended, so it must now either finish or fail.
    terminationHandler f = case f BS.empty of
      Fail msg _ -> throwM (GetException msg)
      Done r lo
        -- Do not push an empty chunk back: a following consumer would then
        -- see @Just ""@ from 'await' instead of end of stream.
        | BS.null lo -> return r
        | otherwise  -> leftover lo >> return r
      -- After being told the input has ended, cereal should never ask for more.
      Partial _ ->
        throwM (GetException "Failed reading: Internal error: unexpected Partial.")

-- | Convert a 'Put' into a 'Source'. Runs in constant memory.
--
-- The 'Put' is rendered with 'runPutLazy', and each strict chunk of the
-- resulting lazy bytestring is yielded downstream in order.
sourcePut :: Monad m => Put -> ConduitT i BS.ByteString m ()
sourcePut put = CL.sourceList (LBS.toChunks (runPutLazy put))

-- | Run a 'Putter' repeatedly on the input stream, producing a concatenated
-- 'BS.ByteString' stream.
--
-- Each upstream value is serialized independently with 'runPut', and the
-- resulting strict bytestring is yielded as one downstream chunk.
conduitPut :: Monad m => Putter a -> ConduitT a BS.ByteString m ()
conduitPut p = CL.map (runPut . p)

-- | Reapply @Get o@ to a stream of bytes as long as more data is available,
-- yielding each new value downstream. This has a few differences from
-- @conduitGet@:
--
-- * If there is a parse failure, the bytes consumed so far by this will not be
-- returned as leftovers. The reason for this is that the only way to guarantee
-- the leftovers will be returned correctly is to hold onto all consumed
-- @ByteString@s, which leads to non-constant memory usage.
--
-- * This function will properly terminate a @Get@ function at end of stream,
-- see https://github.com/snoyberg/conduit/issues/246.
--
-- * @conduitGet@ will pass empty @ByteString@s from the stream directly to
-- cereal, which will trigger cereal to think that the stream has been closed.
-- This breaks the normal abstraction in conduit of ignoring how data is
-- chunked. In @conduitGet2@, all empty @ByteString@s are filtered out and not
-- passed to cereal.
--
-- * After @conduitGet2@ successfully returns, we are guaranteed that there is
-- no data left to be consumed in the stream.
--
-- A parse failure is reported by throwing a 'GetException'.
--
-- @since 0.7.3
conduitGet2 :: MonadThrow m => Get o -> ConduitT BS.ByteString o m ()
conduitGet2 get =
    awaitNE >>= start
  where
    -- Get the next chunk of data, skipping empty chunks, so that an empty
    -- ByteString is returned only at the end of the stream.
    awaitNE = await >>= maybe (return BS.empty) check
      where
        check bs
            | BS.null bs = awaitNE
            | otherwise  = return bs

    -- Begin a fresh parse on a non-empty chunk. An empty chunk here can only
    -- come from 'awaitNE' at end of stream, which is a clean stop.
    start bs
        | BS.null bs = return ()
        | otherwise  = result (runGetPartial get bs)

    result (Fail msg _) = throwM (GetException msg)
    -- This will feed an empty ByteString into f at end of stream, which is how
    -- we indicate to cereal that there is no data left. If we wanted to be
    -- more pedantic, we could ensure that cereal only ever consumes a single
    -- ByteString to avoid a loop, but that is the contract that cereal is
    -- giving us anyway.
    result (Partial f) = awaitNE >>= result . f
    result (Done x rest) = do
        yield x
        -- Resume on the unconsumed tail if there is one; otherwise fetch the
        -- next non-empty chunk (or detect end of stream).
        if BS.null rest
            then awaitNE >>= start
            else start rest