diff --git a/crypto-sodium/crypto-sodium.cabal b/crypto-sodium/crypto-sodium.cabal index 69146a2..d09f055 100644 --- a/crypto-sodium/crypto-sodium.cabal +++ b/crypto-sodium/crypto-sodium.cabal @@ -134,6 +134,7 @@ test-suite test Test.Crypto.Sodium.Pwhash Test.Crypto.Sodium.Random Test.Crypto.Sodium.Salt + Test.Crypto.Sodium.Salt.Internal Test.Crypto.Sodium.Sign Paths_crypto_sodium hs-source-dirs: diff --git a/crypto-sodium/lib/Crypto/Sodium/Salt/Internal.hs b/crypto-sodium/lib/Crypto/Sodium/Salt/Internal.hs index 4e614b4..7e95899 100644 --- a/crypto-sodium/lib/Crypto/Sodium/Salt/Internal.hs +++ b/crypto-sodium/lib/Crypto/Sodium/Salt/Internal.hs @@ -7,8 +7,10 @@ module Crypto.Sodium.Salt.Internal ( parseEscapes ) where +import Control.Monad (liftM2) +import Data.Char (isSpace) import Data.Maybe (listToMaybe) -import Text.ParserCombinators.ReadP (eof, readP_to_S, many) +import Text.ParserCombinators.ReadP (ReadP, readP_to_S, (<++)) import Text.Read.Lex (lexChar) -- | Parse a Haskell string literal with escapes. @@ -19,6 +21,16 @@ import Text.Read.Lex (lexChar) -- -- This function can fail if there are invalid escape sequences. parseEscapes :: MonadFail m => String -> m String -parseEscapes str = case listToMaybe (readP_to_S (many lexChar <* eof) str) of +parseEscapes str = case listToMaybe (readP_to_S (many' lexChar) str) of Just (result, "") -> pure result - _ -> fail $ "Failed to parse raw bytes (no parse): " <> str + Just (_, rest) -> fail $ case rest of + '\\':rest' -> "Failed to parse character escape '\\" + <> takeWhile (\c -> not (isSpace c) && c /= '\\') rest' <> "'" + -- the next case shouldn't happen since 'lexChar' can only fail on escapes + _ -> "Failed to parse string with escapes" + <> " at input position " <> show (length str - length rest) + -- the last case shouldn't happen since parser will happily parse zero characters + _ -> fail $ "Failed to parse string with escapes: " <> str + +many' :: ReadP a -> ReadP [a] +many' p = liftM2 (:) p (many' p) <++ pure [] diff --git a/crypto-sodium/test/Test/Crypto/Sodium/Salt/Internal.hs b/crypto-sodium/test/Test/Crypto/Sodium/Salt/Internal.hs new file mode 100644 index 0000000..754d53a --- /dev/null +++ b/crypto-sodium/test/Test/Crypto/Sodium/Salt/Internal.hs @@ -0,0 +1,40 @@ +-- SPDX-FileCopyrightText: 2022 Serokell +-- +-- SPDX-License-Identifier: MPL-2.0 +{-# LANGUAGE DerivingStrategies #-} + +module Test.Crypto.Sodium.Salt.Internal where + +import Test.HUnit ((@?=), Assertion) + +import Crypto.Sodium.Salt.Internal (parseEscapes) + +newtype ErrM a = ErrM (Either String a) + deriving stock (Show, Eq) + deriving newtype (Functor, Applicative, Monad) + +instance MonadFail ErrM where + fail = ErrM . Left + +ok :: a -> ErrM a +ok = ErrM . Right + +err :: String -> ErrM a +err = ErrM . Left + +unit_parseEscapes_noEscapes, unit_parseEscapes_goodEscapes, unit_parseEscapes_emptyString, + unit_parseEscapes_badEscape, unit_parseEscapes_badEscapeMidLine, + unit_parseEscapes_emptyEscape :: Assertion +unit_parseEscapes_noEscapes = + parseEscapes "no escapes" @?= ok "no escapes" +unit_parseEscapes_goodEscapes = + parseEscapes "\\123\\456" @?= ok "\123\456" +unit_parseEscapes_emptyString = + parseEscapes "" @?= ok "" +unit_parseEscapes_badEscape = + parseEscapes "\\err" @?= err "Failed to parse character escape '\\err' at input position 0" +unit_parseEscapes_badEscapeMidLine = + parseEscapes "some text \\somearbitrarystring other text" + @?= err "Failed to parse character escape '\\somearbitrarystring' at input position 10" +unit_parseEscapes_emptyEscape = + parseEscapes "\\" @?= err "Failed to parse character escape '\\' at input position 0"