Haskell is quite OK for images: decoding QOI

December 18, 2021 // haskell, performance

I’ve recently come across the new “Quite OK Image” format, a fast lossless image compression algorithm. It’s a very straightforward algorithm that’s a pleasure to work with, so, naturally, I got curious what the performance of a Haskell implementation would be if:

  1. I just write reasonably efficient code without getting too deep into low-level details to get the job done in a couple of hours.
  2. I try to push the envelope and see what could be done if one’s actually willing to go into those details (within some limits, of course, so no GHC hacking!)

Turns out that yes, it’s indeed possible to write something with C-level performance in a matter of a couple of hours. Moreover, Haskell’s type system shines here: class-constrained parametric polymorphism lets us use the same decoder implementation for pixels with very different representations, squeezing out as much performance as is reasonably possible without duplicating the code.

In this post, I’ll describe the Haskell implementation of the decoder and the steps I took to get from (1) to (2).

This time I’ll mix the sections up somewhat and start with the benchmarking results.

Benchmarks

I’ll be mostly comparing against the reference C implementation, taken at commit e9069e and built with gcc 11.2 or clang 13 at -O3, with or without -march=native. Our weekend Haskell brainchild is built with the latest Stackage LTS (with GHC 8.10.7), using LLVM 13 as the code generation backend.

Big Fat Warning: the format has changed since that commit (well, it came with no stability promises yet). I’m describing the results of my little experiment nevertheless.

Regarding hardware, I’m testing this on two machines: a dedicated server with a stock Ryzen 7 3700X, and my desktop machine with a Core i7 3930k @ 4.0 GHz, an oldie but a goodie (although I really should start thinking about upgrading!).

Benchmarking methodology is pretty typical, at least for me: run each binary 5 times on a sample input, take the minimum of The Numbers, and put it into The Table. The sample input consists of two images: a 5616×3744 photo and a 5120×2880 artwork (the wallpaper).

The photo is undoubtedly not the most typical input for lossless compression algorithms, but it’s the biggest file I could get my hands on, so measurement noise doesn’t affect the results much. Curiously, the C version performs equally well (within measurement error) on RGB and RGBA files. Our Haskell version, on the other hand, performs noticeably differently depending on the channel count.

Also, while developing my implementation, I was testing the effects of different changes on the photo only. In fact, I only got my hands on the wallpaper after finishing most of the changes, so, in machine learning lingo, the wallpaper is my test set.

All in all, it should be good enough to get a rough idea of how different implementations perform.

The Number is, correspondingly:

Even though this measures slightly different quantities, the results shouldn’t differ much from the actual values. Besides, this way the measured numbers slightly pessimize Haskell, which is OK if we want to see how good Haskell can be.
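The benchmarks time whole binaries, but if you’d like to apply the same min-of-5 methodology from inside a Haskell program, a minimal sketch could look like this. `minTimeMs` is a hypothetical helper, not part of the post’s code, and it only forces the action’s result to weak head normal form: that is enough for a strict unboxed array, but not for a lazy structure such as an `Either` wrapping a lazily computed array.

```haskell
import Control.Exception (evaluate)
import GHC.Clock (getMonotonicTimeNSec)

-- | Hypothetical helper: run an action @runs@ times (at least once) and return
-- the fastest wall-clock time in milliseconds, i.e. "take the minimum".
minTimeMs :: Int -> IO a -> IO Double
minTimeMs runs action = minimum <$> mapM (const timeOnce) [1 .. max 1 runs]
  where
    timeOnce = do
      start <- getMonotonicTimeNSec
      _ <- action >>= evaluate -- force the result to WHNF inside the timed region
      end <- getMonotonicTimeNSec
      pure (fromIntegral (end - start) / 1e6) -- nanoseconds to milliseconds
```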

So, finally, the results.

Core i7 3930k

5616×3744 photo:

Implementation                 | Decoding, ms | % of best C | Encoding, ms | % of best C
C, gcc 11, -O3 -march=native   | 228          | 117%        | 264          | 102%
C, gcc 11, -O3                 | 230          | 118%        | 260          | 100%
C, clang 13, -O3 -march=native | 211          | 108%        | 356          | 137%
C, clang 13, -O3               | 195          | 100%        | 341          | 131%
Haskell, 3-channel             | 172          | 88%         | 219          | 84%
Haskell, 4-channel             | 187          | 96%         | 209          | 80%

5120×2880 artwork:

Implementation                 | Decoding, ms | % of best C | Encoding, ms | % of best C
C, gcc 11, -O3 -march=native   | 90           | 130%        | 73           | 102%
C, gcc 11, -O3                 | 88           | 128%        | 71           | 100%
C, clang 13, -O3 -march=native | 69           | 100%        | 124          | 175%
C, clang 13, -O3               | 70           | 101%        | 124          | 175%
Haskell, 3-channel             | 55           | 79%         | 70           | 99%
Haskell, 4-channel             | 54           | 78%         | 72           | 102%

It’s interesting to note that clang generates noticeably more efficient decoding code but loses badly at encoding. So much for guaranteed C performance.

These results become even more interesting considering that clang and GHC both use the same code generation backend (namely, LLVM).

Ryzen 7 3700X

5616×3744 photo:

Implementation                 | Decoding, ms | % of best C | Encoding, ms | % of best C
C, gcc 11, -O3 -march=native   | 179          | 121%        | 198          | 101%
C, gcc 11, -O3                 | 174          | 118%        | 196          | 100%
C, clang 13, -O3 -march=native | 158          | 107%        | 266          | 136%
C, clang 13, -O3               | 148          | 100%        | 252          | 129%
Haskell, 3-channel             | 132          | 89%         | 191          | 97%
Haskell, 4-channel             | 145          | 98%         | 141          | 72%

5120×2880 artwork:

Implementation                 | Decoding, ms | % of best C | Encoding, ms | % of best C
C, gcc 11, -O3 -march=native   | 68           | 117%        | 56           | 102%
C, gcc 11, -O3                 | 66           | 114%        | 55           | 100%
C, clang 13, -O3 -march=native | 60           | 103%        | 96           | 175%
C, clang 13, -O3               | 58           | 100%        | 90           | 164%
Haskell, 3-channel             | 48           | 83%         | 54           | 98%
Haskell, 4-channel             | 48           | 83%         | 52           | 95%

For completeness, here are some other implementations on the same Ryzen 7 3700X for the 5616×3744 photo:

Implementation           | Decoding, ms | % of best C | Encoding, ms | % of best C
go, go 1.17.2, 3-channel | 303          | 205%        | 1079         | 551%
go, go 1.17.2, 4-channel | 301          | 203%        | 809          | 412%
rust, rust 1.57.0        | 163          | 110%        | 177          | 90%

and for the 5120×2880 artwork:

Implementation           | Decoding, ms | % of best C | Encoding, ms | % of best C
go, go 1.17.2, 3-channel | 78           | 134%        | 607          | 1103%
go, go 1.17.2, 4-channel | 78           | 134%        | 444          | 807%
rust, rust 1.57.0        | 58           | 100%        | 66           | 120%

A few notes:

Peeping in

As an appetizer, I’ll show a couple of versions of the decoder. The most interesting stuff happens in Pixel.hs, which defines what a pixel is, and Decoder.hs, which contains, well, the decoder.

First, a version that can be written in a couple of hours without too much thought about performance and without ever running with profiling enabled. Just general common sense, y’know. This one runs in 191 ms on the i7 3930k (compared with C’s 228 ms with gcc and 195 ms with clang):

Pixel.hs
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE BangPatterns #-}

module Data.Image.Qoi.Pixel where

import qualified Data.Vector.Unboxed.Mutable as VM
import Data.Vector.Unboxed.Deriving
import Data.Word

data Pixel3 = Pixel3 Word8 Word8 Word8 deriving (Show)
data Pixel4 = Pixel4 Word8 Word8 Word8 Word8 deriving (Show)

derivingUnbox "Pixel3"
  [t| Pixel3 -> (Word8, Word8, Word8) |]
  [| \(Pixel3 r g b) -> (r, g, b) |]
  [| \(r, g, b) -> Pixel3 r g b |]

derivingUnbox "Pixel4"
  [t| Pixel4 -> (Word8, Word8, Word8, Word8) |]
  [| \(Pixel4 r g b a) -> (r, g, b, a) |]
  [| \(r, g, b, a) -> Pixel4 r g b a |]

class VM.Unbox a => Pixel a where
  toRGBA :: a -> (Word8, Word8, Word8, Word8)
  fromRGBA :: Word8 -> Word8 -> Word8 -> Word8 -> a

instance Pixel Pixel3 where
  toRGBA (Pixel3 r g b) = (r, g, b, 255)
  fromRGBA r g b _ = Pixel3 r g b

instance Pixel Pixel4 where
  toRGBA (Pixel4 r g b a) = (r, g, b, a)
  fromRGBA r g b a = Pixel4 r g b a

addRGB :: Pixel pixel => pixel -> Word8 -> Word8 -> Word8 -> pixel
addRGB px dr dg db = addRGBA px dr dg db 0

addRGBA :: Pixel pixel => pixel -> Word8 -> Word8 -> Word8 -> Word8 -> pixel
addRGBA px dr dg db da = let (r, g, b, a) = toRGBA px
                          in fromRGBA (r + dr) (g + dg) (b + db) (a + da)

initPixel :: Pixel pixel => pixel
initPixel = fromRGBA 0 0 0 255
Decoder.hs
{-# LANGUAGE Strict #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE BinaryLiterals #-}
{-# OPTIONS_GHC -O2 -fllvm #-}

module Data.Image.Qoi.Decoder(decodeQoi, SomePixels(..)) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString.Unsafe as BS
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as VM
import Data.Binary
import Data.Bits

import Data.Image.Qoi.Format
import Data.Image.Qoi.Pixel

infix 4 !
(!) :: BS.ByteString -> Int -> Word8
(!) = BS.unsafeIndex

infix 8 .>>., .<<.
(.>>.), (.<<.) :: Bits a => a -> Int -> a
(.>>.) = shiftR
(.<<.) = shiftL

data ChunkResult pixel
  = One pixel
  | Repeat Int
  | Lookback Int
  | Stop

peekChunk :: Pixel pixel => BS.ByteString -> Int -> pixel -> (Int, ChunkResult pixel)
peekChunk str pos prevPixel
  | byte .>>. 6 == 0     = (1, Lookback $ fromIntegral $ byte .&. 0b00111111)
  | byte .>>. 5 == 0b010 = (1, Repeat $ fromIntegral $ 1 + byte .&. 0b00011111)
  | byte .>>. 5 == 0b011 = (2, Repeat $ 33 + (fromIntegral (byte .&. 0b00011111) .<<. 8
                                          .|. fromIntegral (str ! pos + 1)
                                             )
                           )
  | byte .>>. 6 == 0b10  = let dr = (byte .>>. 4 .&. 0b11) - 2
                               dg = (byte .>>. 2 .&. 0b11) - 2
                               db = (byte        .&. 0b11) - 2
                            in (1, One $ addRGB prevPixel dr dg db)
  | byte .>>. 5 == 0b110 = let next = str ! pos + 1
                               dr = (byte .&. 0b00011111) - 16
                               dg = (next .>>. 4)         - 8
                               db = (next .&. 0b00001111) - 8
                            in (2, One $ addRGB prevPixel dr dg db)
  | byte .>>. 4 == 0b1110 = let threeBytes :: Word32
                                threeBytes = fromIntegral byte .<<. 16
                                         .|. fromIntegral (str ! pos + 1) .<<. 8
                                         .|. fromIntegral (str ! pos + 2)
                                dr = fromIntegral (threeBytes .>>. 15 .&. 0b11111) - 16
                                dg = fromIntegral (threeBytes .>>. 10 .&. 0b11111) - 16
                                db = fromIntegral (threeBytes .>>. 5  .&. 0b11111) - 16
                                da = fromIntegral (threeBytes         .&. 0b11111) - 16
                             in (3, One $ addRGBA prevPixel dr dg db da)
  | byte .>>. 4 == 0b1111 = let hr = byte .>>. 3 .&. 0b1
                                hg = byte .>>. 2 .&. 0b1
                                hb = byte .>>. 1 .&. 0b1
                                ha = byte        .&. 0b1
                                (r, g, b, a) = toRGBA prevPixel
                                r' = if hr == 1 then str ! pos + 1 else r
                                g' = if hg == 1 then str ! pos + 1 + fromIntegral hr else g
                                b' = if hb == 1 then str ! pos + 1 + fromIntegral (hr + hg) else b
                                a' = if ha == 1 then str ! pos + 1 + fromIntegral (hr + hg + hb) else a
                             in (1 + fromIntegral (hr + hg + hb + ha), One $ fromRGBA r' g' b' a')
  | otherwise = (0, Stop)
  where
    byte = str ! pos

maxRunLen :: Int
maxRunLen = 8224

decodePixels :: Pixel pixel => BS.ByteString -> Int -> V.Vector pixel
decodePixels str n = V.create $ do
  mvec <- VM.new $ n + maxRunLen

  running <- VM.replicate 64 initPixel

  let updateRunning px = let (r, g, b, a) = toRGBA px
                          in VM.unsafeWrite running (fromIntegral $ (r `xor` g `xor` b `xor` a) .&. 0b00111111) px

  let step inPos outPos prevPixel
        | outPos < n = do
            let (diff, chunk) = peekChunk str inPos prevPixel
            case chunk of
                 One px       -> do VM.unsafeWrite mvec outPos px
                                    updateRunning px
                                    step (inPos + diff) (outPos + 1)   px
                 Lookback pos -> do px <- VM.unsafeRead running pos
                                    VM.unsafeWrite mvec outPos px
                                    step (inPos + diff) (outPos + 1)   px
                 Repeat cnt   -> do VM.set (VM.unsafeSlice outPos cnt mvec) prevPixel
                                    step (inPos + diff) (outPos + cnt) prevPixel
                 Stop         -> pure ()
        | otherwise = pure ()
  step 0 0 initPixel

  pure $ VM.take n mvec

data SomePixels
  = Pixels3 (V.Vector Pixel3)
  | Pixels4 (V.Vector Pixel4)

data DecodeError
  = HeaderError String
  | UnsupportedChannels Int
  | UnpaddedFile
  deriving (Show)

decodeQoi :: BS.ByteString -> Either DecodeError (Header, SomePixels)
decodeQoi str = decodeWHeader str $ decodeOrFail $ BSL.fromStrict $ BS.take 14 str
  where
    decodeWHeader _ (Left (_, _, err)) = Left $ HeaderError err
    decodeWHeader str (Right (_, consumed, header))
      | any (\i -> (str ! BS.length str - i) /= 0) [1..4] = Left UnpaddedFile
      | hChannels header == 3 = Right (header, Pixels3 decode')
      | hChannels header == 4 = Right (header, Pixels4 decode')
      | otherwise = Left $ UnsupportedChannels $ fromIntegral $ hChannels header
      where
        decode' :: Pixel pixel => V.Vector pixel
        decode' = decodePixels (BS.drop (fromIntegral consumed) str) (fromIntegral $ hWidth header * hHeight header)
This one, on the other hand, is the fastest version I managed to come up with after a few more hours, and its results are the ones in the tables above:
Pixel.hs
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BinaryLiterals #-}
{-# OPTIONS_GHC -O2 -fllvm #-}

module Data.Image.Qoi.Pixel where

import qualified Data.Array.Base as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import Control.Monad.ST
import Data.Bits
import Data.Word
import Foreign
import GHC.Base
import GHC.ST
import GHC.Word

import Data.Image.Qoi.Util

data Pixel3 = Pixel3 Word8 Word8 Word8 deriving (Show, Eq)
newtype Pixel4 = Pixel4 Word32 deriving (Show, Eq)

instance A.MArray (A.STUArray s) Pixel3 (ST s) where
  getBounds (A.STUArray l u _ _) = pure (l, u)
  {-# INLINE getBounds #-}
  getNumElements (A.STUArray _ _ n _) = pure n
  {-# INLINE getNumElements #-}

  newArray_ arrBounds = A.newArray arrBounds (Pixel3 0 0 0)
  {-# INLINE newArray_ #-}
  unsafeNewArray_ (l, u) = A.unsafeNewArraySTUArray_ (l, u) (*# 3#)
  {-# INLINE unsafeNewArray_ #-}

  unsafeRead (A.STUArray _ _ _ marr#) (I# n#) = ST $ \s1# ->
    let n'# = n# *# 3#
        !(# s2#, r# #) = readWord8Array# marr# n'#         s1#
        !(# s3#, g# #) = readWord8Array# marr# (n'# +# 1#) s2#
        !(# s4#, b# #) = readWord8Array# marr# (n'# +# 2#) s3#
     in (# s4#, Pixel3 (W8# r#) (W8# g#) (W8# b#) #)
  {-# INLINE unsafeRead #-}
  unsafeWrite (A.STUArray _ _ _ marr#) (I# n#) (Pixel3 (W8# r#) (W8# g#) (W8# b#)) = ST $ \s1# ->
    let n'# = n# *# 3#
        s2# = writeWord8Array# marr# n'#         r# s1#
        s3# = writeWord8Array# marr# (n'# +# 1#) g# s2#
        s4# = writeWord8Array# marr# (n'# +# 2#) b# s3#
     in (# s4#, () #)
  {-# INLINE unsafeWrite #-}

instance A.IArray A.UArray Pixel3 where
  bounds (A.UArray l u _ _) = (l, u)
  {-# INLINE bounds #-}
  numElements (A.UArray  _ _ n _) = n
  {-# INLINE numElements #-}
  unsafeArray lu ies = runST (A.unsafeArrayUArray lu ies $ Pixel3 0 0 0)
  {-# INLINE unsafeArray #-}
  unsafeAt (A.UArray _ _ _ arr#) (I# n#) = Pixel3 (W8# (indexWord8Array# arr# n'#))
                                                  (W8# (indexWord8Array# arr# (n'# +# 1#)))
                                                  (W8# (indexWord8Array# arr# (n'# +# 2#)))
    where
      n'# = n# *# 3#
  {-# INLINE unsafeAt #-}

instance A.MArray (A.STUArray s) Pixel4 (ST s) where
  getBounds (A.STUArray l u _ _) = pure (l, u)
  {-# INLINE getBounds #-}
  getNumElements (A.STUArray _ _ n _) = pure n
  {-# INLINE getNumElements #-}

  newArray_ arrBounds = A.newArray arrBounds (Pixel4 0)
  {-# INLINE newArray_ #-}
  unsafeNewArray_ (l, u) = A.unsafeNewArraySTUArray_ (l, u) (*# 4#)
  {-# INLINE unsafeNewArray_ #-}

  unsafeRead (A.STUArray _ _ _ marr#) (I# n#) = ST $ \s1# ->
    let !(# s2#, rgba# #) = readWord32Array# marr# n# s1#
     in (# s2#, Pixel4 (W32# rgba#) #)
  {-# INLINE unsafeRead #-}
  unsafeWrite (A.STUArray _ _ _ marr#) (I# n#) (Pixel4 (W32# rgba#)) = ST $ \s1# ->
    let s2# = writeWord32Array# marr# n# rgba# s1#
     in (# s2#, () #)
  {-# INLINE unsafeWrite #-}

instance A.IArray A.UArray Pixel4 where
  bounds (A.UArray l u _ _) = (l, u)
  {-# INLINE bounds #-}
  numElements (A.UArray  _ _ n _) = n
  {-# INLINE numElements #-}
  unsafeArray lu ies = runST (A.unsafeArrayUArray lu ies $ Pixel4 0)
  {-# INLINE unsafeArray #-}
  unsafeAt (A.UArray _ _ _ arr#) (I# n#) = Pixel4 (W32# (indexWord32Array# arr# n#))
  {-# INLINE unsafeAt #-}

class (Eq a, forall s. A.MArray (A.STUArray s) a (ST s)) => Pixel a where
  toRGBA :: a -> (Word8, Word8, Word8, Word8)
  fromRGBA :: Word8 -> Word8 -> Word8 -> Word8 -> a

  readPixel :: BS.ByteString -> Int -> a
  channelCount :: proxy a -> Int

bytize :: Word32 -> Word32
bytize = (.&. 0b11111111)

instance Pixel Pixel3 where
  toRGBA (Pixel3 r g b) = (r, g, b, 255)
  {-# INLINE toRGBA #-}
  fromRGBA r g b _ = Pixel3 r g b
  {-# INLINE fromRGBA #-}

  readPixel str pos = Pixel3 (str ! pos) (str ! pos + 1) (str ! pos + 2)
  {-# INLINE readPixel #-}
  channelCount _ = 3
  {-# INLINE channelCount #-}

instance Pixel Pixel4 where
  toRGBA (Pixel4 rgba) = ( fromIntegral $ rgba .>>. 24
                         , fromIntegral $ rgba .>>. 16
                         , fromIntegral $ rgba .>>. 8
                         , fromIntegral   rgba
                         )
  {-# INLINE toRGBA #-}
  fromRGBA r g b a = Pixel4 $ fromIntegral r .<<. 24
                          .|. fromIntegral g .<<. 16
                          .|. fromIntegral b .<<. 8
                          .|. fromIntegral a
  {-# INLINE fromRGBA #-}

  readPixel (BSI.PS x off _) pos = Pixel4 $ BSI.accursedUnutterablePerformIO $ withForeignPtr x $ \p -> peek (p `plusPtr` (off + pos))
  {-# INLINE readPixel #-}
  channelCount _ = 4
  {-# INLINE channelCount #-}

addRGB :: Pixel pixel => pixel -> Word8 -> Word8 -> Word8 -> pixel
addRGB px dr dg db = addRGBA px dr dg db 0
{-# INLINE addRGB #-}

addRGBA :: Pixel pixel => pixel -> Word8 -> Word8 -> Word8 -> Word8 -> pixel
addRGBA px dr dg db da = let (r, g, b, a) = toRGBA px
                          in fromRGBA (r + dr) (g + dg) (b + db) (a + da)
{-# INLINE addRGBA #-}

pixelHash :: (Num a, Pixel pixel) => pixel -> a
pixelHash px = fromIntegral $ (r `xor` g `xor` b `xor` a) .&. 0b00111111
  where (r, g, b, a) = toRGBA px
{-# INLINE pixelHash #-}

updateRunning :: Pixel pixel => A.STUArray s Int pixel -> pixel -> ST s ()
updateRunning running px = A.unsafeWrite running (pixelHash px) px
{-# INLINE updateRunning #-}
Decoder.hs
{-# LANGUAGE Strict #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -O2 -fllvm #-}

module Data.Image.Qoi.Decoder
( decodeQoi
, DecodeError(..)
, SomePixels(..)
) where

import qualified Data.Array.Base as A
import qualified Data.Array.ST as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Control.Monad
import Data.Binary
import Data.Binary.Get
import Data.Bits

import Data.Image.Qoi.Format
import Data.Image.Qoi.Pixel
import Data.Image.Qoi.Util

data ChunkResult pixel
  = One pixel
  | Repeat Int
  | Lookback Int
  | Stop

peekChunk :: Pixel pixel => BS.ByteString -> Int -> pixel -> (Int, ChunkResult pixel)
peekChunk str pos prevPixel
  | byte .>>. 6 == 0     = (1, Lookback $ fromIntegral $ byte .&. 0b00111111)
  | byte .>>. 5 == 0b010 = (1, Repeat $ fromIntegral $ 1 + byte .&. 0b00011111)
  | byte .>>. 5 == 0b011 = (2, Repeat $ 33 + (fromIntegral (byte .&. 0b00011111) .<<. 8
                                          .|. fromIntegral (str ! pos + 1)
                                             )
                           )
  | byte .>>. 6 == 0b10  = let dr = (byte .>>. 4 .&. 0b11) - 2
                               dg = (byte .>>. 2 .&. 0b11) - 2
                               db = (byte        .&. 0b11) - 2
                            in (1, One $ addRGB prevPixel dr dg db)
  | byte .>>. 5 == 0b110 = let next = str ! pos + 1
                               dr = (byte .&. 0b00011111) - 16
                               dg = (next .>>. 4)         - 8
                               db = (next .&. 0b00001111) - 8
                            in (2, One $ addRGB prevPixel dr dg db)
  | byte .>>. 4 == 0b1110 = let threeBytes :: Word32
                                threeBytes = fromIntegral byte .<<. 16
                                         .|. fromIntegral (str ! pos + 1) .<<. 8
                                         .|. fromIntegral (str ! pos + 2)
                                dr = fromIntegral (threeBytes .>>. 15 .&. 0b11111) - 16
                                dg = fromIntegral (threeBytes .>>. 10 .&. 0b11111) - 16
                                db = fromIntegral (threeBytes .>>. 5  .&. 0b11111) - 16
                                da = fromIntegral (threeBytes         .&. 0b11111) - 16
                             in (3, One $ addRGBA prevPixel dr dg db da)
  | byte .>>. 4 == 0b1111 = let hr = byte .>>. 3 .&. 0b1
                                hg = byte .>>. 2 .&. 0b1
                                hb = byte .>>. 1 .&. 0b1
                                ha = byte        .&. 0b1
                                (r, g, b, a) = toRGBA prevPixel
                                r' = if hr == 1 then str ! pos + 1 else r
                                g' = if hg == 1 then str ! pos + 1 + fromIntegral hr else g
                                b' = if hb == 1 then str ! pos + 1 + fromIntegral (hr + hg) else b
                                a' = if ha == 1 then str ! pos + 1 + fromIntegral (hr + hg + hb) else a
                             in (1 + fromIntegral (hr + hg + hb + ha), One $ fromRGBA r' g' b' a')
  | otherwise = (0, Stop)
  where
    byte = str ! pos
{-# INLINE peekChunk #-}

maxRunLen :: Int
maxRunLen = 8224

decodePixels :: Pixel pixel => BS.ByteString -> Int -> Int -> A.UArray Int pixel
decodePixels str strFrom n = A.runSTUArray $ do
  (mvec :: A.STUArray s Int pixel) <- A.unsafeNewArray_ (0, n + maxRunLen - 1)

  running <- A.newArray @(A.STUArray s) (0, 63 :: Int) initPixel

  let step inPos outPos prevPixel
        | outPos < n = do
            let (diff, chunk) = peekChunk str inPos prevPixel
            case chunk of
                 One px       -> do A.unsafeWrite mvec outPos px
                                    updateRunning running px
                                    step (inPos + diff) (outPos + 1)   px
                 Lookback pos -> do px <- A.unsafeRead running pos
                                    A.unsafeWrite mvec outPos px
                                    step (inPos + diff) (outPos + 1)   px
                 Repeat cnt   -> do forM_ [0..cnt - 1] $ \i -> A.unsafeWrite mvec (outPos + i) prevPixel
                                    step (inPos + diff) (outPos + cnt) prevPixel
                 Stop         -> pure outPos
        | otherwise = pure outPos
  finish <- step strFrom 0 initPixel

  forM_ [finish .. n - 1] $ \i -> A.unsafeWrite mvec i initPixel

  pure $ unsafeShrink mvec n
{-# INLINE decodePixels #-}

data SomePixels where
  Pixels3 :: A.UArray Int Pixel3 -> SomePixels
  Pixels4 :: A.UArray Int Pixel4 -> SomePixels
  deriving (Show)

data DecodeError
  = HeaderError String
  | UnsupportedChannels Int
  | UnpaddedFile
  deriving (Show)

decodeWHeader :: BS.ByteString
              -> Either (BSL.ByteString, ByteOffset, String) (BSL.ByteString, ByteOffset, Header)
              -> Either DecodeError (Header, SomePixels)
decodeWHeader _ (Left (_, _, err)) = Left $ HeaderError err
decodeWHeader str (Right (_, consumed, header))
  | any (\i -> (str ! BS.length str - i) /= 0) [1..4] = Left UnpaddedFile
  | hChannels header == 3 = Right (header, Pixels3 decode')
  | hChannels header == 4 = Right (header, Pixels4 decode')
  | otherwise = Left $ UnsupportedChannels $ fromIntegral $ hChannels header
  where
    decode' :: Pixel pixel => A.UArray Int pixel
    decode' = decodePixels str (fromIntegral consumed) (fromIntegral $ hWidth header * hHeight header)
    {-# INLINE decode' #-}

decodeQoi :: BS.ByteString -> Either DecodeError (Header, SomePixels)
decodeQoi str = decodeWHeader str $ decodeOrFail $ BSL.fromStrict $ BS.take 14 str
{-# NOINLINE decodeQoi #-}
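To make the packed Pixel4 representation a bit more tangible, here’s a standalone sketch of the same shift-and-mask packing and the 6-bit hash, stripped of the array instances. The layout (red in the most significant byte, alpha in the least) mirrors fromRGBA/toRGBA above; a nice consequence is that comparing two pixels becomes a single Word32 comparison. This is a simplified illustration, not the post’s module.

```haskell
import Data.Bits (shiftL, shiftR, xor, (.&.), (.|.))
import Data.Word (Word32, Word8)

-- Packed RGBA pixel: R in the most significant byte, A in the least,
-- mirroring fromRGBA/toRGBA of the faster Pixel4 above.
newtype Pixel4 = Pixel4 Word32 deriving (Show, Eq)

fromRGBA :: Word8 -> Word8 -> Word8 -> Word8 -> Pixel4
fromRGBA r g b a = Pixel4 $ fromIntegral r `shiftL` 24
                        .|. fromIntegral g `shiftL` 16
                        .|. fromIntegral b `shiftL` 8
                        .|. fromIntegral a

-- fromIntegral from Word32 to Word8 keeps only the low 8 bits.
toRGBA :: Pixel4 -> (Word8, Word8, Word8, Word8)
toRGBA (Pixel4 w) = ( fromIntegral (w `shiftR` 24)
                    , fromIntegral (w `shiftR` 16)
                    , fromIntegral (w `shiftR` 8)
                    , fromIntegral w
                    )

-- Index slot: XOR of all channels, masked to 0..63 (parentheses matter,
-- since .&. binds tighter than `xor`).
pixelHash :: Pixel4 -> Int
pixelHash px = fromIntegral ((r `xor` g `xor` b `xor` a) .&. 63)
  where (r, g, b, a) = toRGBA px
```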

Implementation

First of all, I’ll note a few implementation details, most of which apply to any computation-heavy code in general:

I’ll be using a ByteString as my input: that’s what you’ll most likely get when reading a file, a network reply, or an mmap-ed file.

Reading the header

I’m using the binary package with a library of mine providing helpers for Generic-based derivation of Binary instances. This way, it’s sufficient to define the data structure for the format header:

data Header = Header
  { hMagic :: MatchASCII "QOI magic" "qoif"
  , hWidth :: Word32
  , hHeight :: Word32
  , hChannels :: Word8
  , hColorspace :: Word8
  } deriving (Eq, Show, Generic, Binary)

and the libraries generate all the (de)serialization boilerplate.
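If you’d rather not pull in an extra library for the header, a hand-written parser with the binary package’s Get monad is short as well. This is a sketch assuming big-endian width and height, which is what binary’s own Word32 instance (and hence a Generic-derived parser) reads; SimpleHeader, getHeader and parseHeader are hypothetical names.

```haskell
import Control.Monad (unless)
import Data.Binary.Get (ByteOffset, Get, getByteString, getWord32be, getWord8, runGetOrFail)
import qualified Data.ByteString.Char8 as BS8
import qualified Data.ByteString.Lazy as BSL
import Data.Word (Word32, Word8)

-- Hypothetical hand-written counterpart of the Generic-derived Header.
data SimpleHeader = SimpleHeader
  { shWidth      :: Word32
  , shHeight     :: Word32
  , shChannels   :: Word8
  , shColorspace :: Word8
  } deriving (Eq, Show)

-- 14 bytes: "qoif", big-endian width and height, then channels and colorspace.
getHeader :: Get SimpleHeader
getHeader = do
  magic <- getByteString 4
  unless (magic == BS8.pack "qoif") $ fail "QOI magic: expected \"qoif\""
  SimpleHeader <$> getWord32be <*> getWord32be <*> getWord8 <*> getWord8

-- Same result shape as decodeOrFail: the header plus the number of bytes consumed.
parseHeader :: BSL.ByteString -> Either String (SimpleHeader, ByteOffset)
parseHeader input = case runGetOrFail getHeader input of
  Left (_, _, err)       -> Left err
  Right (_, consumed, h) -> Right (h, consumed)
```

The consumed byte count is what tells the pixel decoder where the chunk stream starts.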

I’ll also define a few helper operators to make the code a tad more concise:

infix 4 !
(!) :: BS.ByteString -> Int -> Word8
(!) = BS.unsafeIndex

infix 8 .>>., .<<.
(.>>.), (.<<.) :: Bits a => a -> Int -> a
(.>>.) = shiftR
(.<<.) = shiftL

Writing str ! idx is indeed nicer than str `BS.unsafeIndex` idx, and so is byte .>>. 3 compared to byte `shiftR` 3. I’ve also chosen the fixities of the operators to avoid extra parentheses in complex expressions.
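Here’s a quick sanity check of how these fixities play out; nextByte and middleBits are hypothetical helpers written only to illustrate the parse. The explicit Data.Bits import list is there because, as far as I know, recent versions of base (4.17 and later) export their own (.>>.) and (.<<.), which would otherwise clash.

```haskell
import Data.Bits (Bits, shiftL, shiftR, (.&.))
import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BS
import Data.Word (Word8)

-- The post's helpers, with an explicit Data.Bits import list to avoid name clashes.
infix 4 !
(!) :: BS.ByteString -> Int -> Word8
(!) = BS.unsafeIndex

infix 8 .>>., .<<.
(.>>.), (.<<.) :: Bits a => a -> Int -> a
(.>>.) = shiftR
(.<<.) = shiftL

-- (+) is infixl 6, tighter than infix 4, so this reads the byte at index pos + 1.
nextByte :: BS.ByteString -> Int -> Word8
nextByte str pos = str ! pos + 1

-- (.>>.) at 8 binds tighter than (.&.) at infixl 7: this is (byte .>>. 4) .&. 3.
middleBits :: Word8 -> Word8
middleBits byte = byte .>>. 4 .&. 3
```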

Initial attempt

Next, let’s define the basic unit of work we’ll handle: a pixel. We’ll have two different pixel types depending on the number of channels: a 3-channel RGB pixel and a 4-channel RGBA pixel:

data Pixel3 = Pixel3 Word8 Word8 Word8 deriving (Show)
data Pixel4 = Pixel4 Word8 Word8 Word8 Word8 deriving (Show)

I’m keeping these two as distinct types (as opposed to two constructors of a single Pixel type) to simplify the compiler’s job. Namely, the compiler is more likely to specialize the code accordingly if it statically knows which constructor is used. We’ll talk more about specialization later.

Having said that, let’s focus on RGB pixels for now. We’ll generalize to arbitrary pixel types later.

Now, the compression algorithm requires us to keep an index of 64 previously seen pixels, each stored in the slot given by a hash of its value. So what options do we have here? It all comes down to three possibilities:

Thus, we’re left with unboxed mutable vectors.

The problem is that Data.Vector.Unboxed doesn’t know about our pixels, so it can’t be used with them.

At least, not without some extra definitions. Those are easy to get with the vector-th-unbox library:

derivingUnbox "Pixel3"
  [t| Pixel3 -> (Word8, Word8, Word8) |]
  [| \(Pixel3 r g b) -> (r, g, b) |]
  [| \(r, g, b) -> Pixel3 r g b |]

Indeed, our Pixel3s are isomorphic to triples of Word8s, and Data.Vector.Unboxed knows perfectly well how to store those.
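The idea behind that isomorphism can also be spelled out by hand: store each Pixel3 as three consecutive bytes of a flat unboxed array, which is roughly what the faster version’s MArray instance above does with raw primops. The sketch below instead uses the array package’s bounds-checked readArray/writeArray on a 64-entry index; hash3 follows the XOR-based hash used in this post’s version of the format, and Index, newIndex, writePixel, readPixel and remember are hypothetical names.

```haskell
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newArray, readArray, writeArray)
import Data.Bits (xor, (.&.))
import Data.Word (Word8)

data Pixel3 = Pixel3 Word8 Word8 Word8 deriving (Show, Eq)

-- XOR of all channels (alpha is implicitly 255 for RGB pixels), masked to 0..63.
hash3 :: Pixel3 -> Int
hash3 (Pixel3 r g b) = fromIntegral ((r `xor` g `xor` b `xor` 255) .&. 63)

-- A 64-entry index stored as 192 flat bytes: pixel i lives at offsets 3*i .. 3*i+2.
newtype Index s = Index (STUArray s Int Word8)

newIndex :: ST s (Index s)
newIndex = Index <$> newArray (0, 64 * 3 - 1) 0

writePixel :: Index s -> Int -> Pixel3 -> ST s ()
writePixel (Index arr) i (Pixel3 r g b) = do
  writeArray arr (3 * i) r
  writeArray arr (3 * i + 1) g
  writeArray arr (3 * i + 2) b

readPixel :: Index s -> Int -> ST s Pixel3
readPixel (Index arr) i =
  Pixel3 <$> readArray arr (3 * i) <*> readArray arr (3 * i + 1) <*> readArray arr (3 * i + 2)

-- Remember a pixel in the slot selected by its hash.
remember :: Index s -> Pixel3 -> ST s ()
remember idx px = writePixel idx (hash3 px) px
```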

We’ll also need a little bit of a vector space on our Pixel3s:

addPixel3 :: Pixel3 -> Pixel3 -> Pixel3
addPixel3 (Pixel3 r1 g1 b1) (Pixel3 r2 g2 b2) = Pixel3 (r1 + r2) (g1 + g2) (b1 + b2)
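Since Word8 arithmetic wraps modulo 256, this little addition is exactly what QOI’s differences need: a “negative” delta stored as a Word8 can simply be added. A quick demonstration; signedDelta is a hypothetical helper that stores a signed difference in a Word8.

```haskell
import Data.Word (Word8)

data Pixel3 = Pixel3 Word8 Word8 Word8 deriving (Show, Eq)

-- Channel-wise addition; Word8 arithmetic wraps modulo 256.
addPixel3 :: Pixel3 -> Pixel3 -> Pixel3
addPixel3 (Pixel3 r1 g1 b1) (Pixel3 r2 g2 b2) = Pixel3 (r1 + r2) (g1 + g2) (b1 + b2)

-- Hypothetical helper: store a signed difference (e.g. -2) as its Word8
-- two's-complement bit pattern (254) by truncating modulo 256.
signedDelta :: Int -> Word8
signedDelta = fromIntegral
```

For example, adding (signedDelta (-2), 2, 1) to Pixel3 1 0 255 gives Pixel3 255 2 0: red goes from 1 to -1 (that is, 255) and blue wraps from 255 to 0, matching the modular semantics the format expects.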

Now to the decoding!

To separate the pure and the locally mutable parts, and to get more modularity and testability (not that I’m going to use the latter…), we’ll write a separate function to decode a single chunk. This function takes an input position in the ByteString and the previous pixel. The (successful) decoding result is a pair of an Int and a ChunkResult: the Int says how many bytes were consumed, and the ChunkResult is the decoded chunk:

data ChunkResult pixel
  = One pixel
  | Repeat Int
  | Lookback Int

Here, ChunkResult is already parameterized by the pixel type: it doesn’t care whether it holds a Pixel3 or Pixel4.

The function for decoding a chunk basically follows the spec:

peekChunk :: BS.ByteString -> Int -> Pixel3 -> (Int, ChunkResult Pixel3)
peekChunk str pos prevPixel
  | byte .>>. 6 == 0     = (1, Lookback $ fromIntegral $ byte .&. 0b00111111)
  | byte .>>. 5 == 0b010 = (1, Repeat $ fromIntegral $ 1 + byte .&. 0b00011111)
  | byte .>>. 5 == 0b011 = (2, Repeat $ 33 + (fromIntegral (byte .&. 0b00011111) .<<. 8
                                          .|. fromIntegral (str ! pos + 1)
                                             )
                           )
  | byte .>>. 6 == 0b10  = let dr = (byte .>>. 4 .&. 0b11) - 2
                               dg = (byte .>>. 2 .&. 0b11) - 2
                               db = (byte        .&. 0b11) - 2
                            in (1, One $ prevPixel `addPixel3` Pixel3 dr dg db)
  | byte .>>. 5 == 0b110 = let next = str ! pos + 1
                               dr = (byte .&. 0b00011111) - 16
                               dg = (next .>>. 4)         - 8
                               db = (next .&. 0b00001111) - 8
                            in (2, One $ prevPixel `addPixel3` Pixel3 dr dg db)
  | byte .>>. 4 == 0b1110 = let threeBytes :: Word32
                                threeBytes = fromIntegral byte .<<. 16
                                         .|. fromIntegral (str ! pos + 1) .<<. 8
                                         .|. fromIntegral (str ! pos + 2)
                                dr = fromIntegral (threeBytes .>>. 15 .&. 0b11111) - 16
                                dg = fromIntegral (threeBytes .>>. 10 .&. 0b11111) - 16
                                db = fromIntegral (threeBytes .>>. 5  .&. 0b11111) - 16
                                --da = fromIntegral (threeBytes         .&. 0b11111) - 16
                             in (3, One $ prevPixel `addPixel3` Pixel3 dr dg db)
  | byte .>>. 4 == 0b1111 = let hr = byte .>>. 3 .&. 0b1
                                hg = byte .>>. 2 .&. 0b1
                                hb = byte .>>. 1 .&. 0b1
                                ha = byte        .&. 0b1
                                Pixel3 r g b = prevPixel
                                r' = (negate hr .&. (str ! pos + 1))
                                 .|. (hr - 1)   .&. r
                                g' = (negate hg .&. (str ! pos + 1 + fromIntegral hr))
                                 .|. (hg - 1)   .&. g
                                b' = (negate hb .&. (str ! pos + 1 + fromIntegral (hr + hg)))
                                 .|. (hb - 1)   .&. b
                                --a = negate ha .&. (str ! pos + 1 + fromIntegral (hr + hg + hb))
                             in (1 + fromIntegral (hr + hg + hb + ha), One $ Pixel3 r' g' b')
  | otherwise = error "unknown byte"
  where
    byte = str ! pos
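A worked example for two pieces of peekChunk: decoding the 1-byte diff chunk (each 2-bit field biased by 2), and the branchless select in the 0b1111 branch, which relies on negate h and h - 1 being all-ones or all-zeros masks when the flag h is 0 or 1. diff8 and selectByFlag are hypothetical standalone helpers extracted for illustration.

```haskell
import Data.Bits (shiftR, (.&.), (.|.))
import Data.Word (Word8)

-- The 0b10 branch of peekChunk: three 2-bit fields, each biased by 2.
-- Subtraction happens in Word8, so -1 comes out as 255.
diff8 :: Word8 -> (Word8, Word8, Word8)
diff8 byte = ( (byte `shiftR` 4 .&. 3) - 2
             , (byte `shiftR` 2 .&. 3) - 2
             , (byte .&. 3) - 2
             )

-- For h == 1: negate h == 0xFF and h - 1 == 0x00, so `new` is kept.
-- For h == 0: negate h == 0x00 and h - 1 == 0xFF, so `old` is kept.
selectByFlag :: Word8 -> Word8 -> Word8 -> Word8
selectByFlag h new old = (negate h .&. new) .|. ((h - 1) .&. old)
```

For 0x9B (0b10_01_10_11), the fields are 1, 2 and 3, giving deltas of -1, 0 and +1, that is, (255, 0, 1) as Word8s.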

Note that there’s already some code for handling the fourth (alpha) channel here, albeit commented out. Ideally, we’ll have the same function for RGBA pixels, and the compiler will be able to comment this code out for 3-channel RGB pixels for me.

Also, this last error makes me feel uneasy. We’ll deal with it later.

Anyway, having this pure function, we can now decode the whole ByteString into a vector of precomputed length n: we just repeatedly parse a chunk until we fill the entire output array.

Implementation-wise, we could almost have used Vector’s unfoldrExactNM, which takes the number of elements, a function producing the next element along with the next “state” value (that’s almost our peekChunk), and the seed “state” value (the 0th position in the input string and the initial pixel), and constructs the vector. Unfortunately, unfoldrExactNM expects the function to return a single element, while our peekChunk returns either a single next element (in case of One or Lookback) or, effectively, a vector of identical elements (in case of Repeat). While we could hack our way around combining these two styles, I’d expect the resulting performance to suffer.

Instead, we just manually roll our own loop:

import qualified Data.ByteString as BS
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as VM
import Data.Bits

peekChunk :: BS.ByteString -> Int -> Pixel3 -> (Int, ChunkResult Pixel3)
peekChunk = {- as above -}

decode3ch :: BS.ByteString -> Int -> V.Vector Pixel3
decode3ch str n = V.create $ do
  mvec <- VM.new n

  running <- VM.replicate 64 $ Pixel3 0 0 0

  let updateRunning3 px@(Pixel3 r g b) = VM.unsafeWrite running (fromIntegral $ (r `xor` g `xor` b `xor` 255) .&. 0b00111111) px

  let step inPos outPos prevPixel
        | outPos < n = do
            let (diff, chunk) = peekChunk str inPos prevPixel
            case chunk of
                 One px       -> do VM.unsafeWrite mvec outPos px
                                    updateRunning3 px
                                    step (inPos + diff) (outPos + 1)   px
                 Lookback pos -> do px <- VM.unsafeRead running pos
                                    VM.unsafeWrite mvec outPos px
                                    step (inPos + diff) (outPos + 1)   px
                 Repeat cnt   -> do VM.set (VM.unsafeSlice outPos cnt mvec) prevPixel
                                    step (inPos + diff) (outPos + cnt) prevPixel
        | otherwise = pure ()
  step 0 0 (Pixel3 0 0 0)

  pure mvec

I’d expect this to perform reasonably well since this is obvious tail recursion, at least, after peeling off all the ST stuff. And GHC just loves to turn such tail recursion into tight loops!
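To make the shape of this loop concrete without the QOI details, here is a stripped-down, runnable sketch. It assumes a hypothetical ToyChunk type with Word8 “pixels” and uses STUArray from the array package (bundled with GHC) instead of vector. Unlike decode3ch, it clamps each run to the remaining space instead of relying on unchecked writes.

```haskell
import Data.Array.ST (newArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, elems)
import Data.Word (Word8)

-- Hypothetical, simplified chunk type: a chunk either yields one new pixel
-- or repeats the previous pixel the given number of times.
data ToyChunk = ToyOne Word8 | ToyRepeat Int

-- Expand chunks into exactly n pixels with a hand-rolled tail-recursive loop,
-- mirroring the shape of decode3ch.
expandChunks :: Int -> [ToyChunk] -> UArray Int Word8
expandChunks n chunks = runSTUArray $ do
  out <- newArray (0, n - 1) 0 -- zero-initialized, like VM.new
  let step cs outPos prev
        | outPos >= n = pure () -- output is full
        | otherwise = case cs of
            [] -> pure () -- input exhausted: the rest stays 0
            ToyOne px : rest -> do
              writeArray out outPos px
              step rest (outPos + 1) px
            ToyRepeat cnt : rest -> do
              -- clamp the run so it never writes past the end of the output
              let cnt' = min cnt (n - outPos)
              mapM_ (\i -> writeArray out (outPos + i) prev) [0 .. cnt' - 1]
              step rest (outPos + cnt') prev
  step chunks 0 0
  pure out
```

For example, expandChunks 5 [ToyOne 1, ToyRepeat 3, ToyOne 2] gives the elements [1, 1, 1, 1, 2], and a stream that runs out early leaves the remaining slots at 0.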

Let’s write some extra wrappers and see:

decodeQoi :: BS.ByteString -> Maybe (Header, V.Vector Pixel3)
decodeQoi str = case decodeOrFail $ BSL.fromStrict $ BS.take 14 str of
                     Left _ -> Nothing
                     Right (_, _, h@Header { .. }) -> Just (h, decode3ch (BS.drop 14 str) (fromIntegral $ hWidth * hHeight))

How fast does this perform? Running 5 times on that test image and taking the minimum gives me 231 ms on my Core i7 3930k. Remember that our C baseline is 230 ms when compiled with gcc and 195 ms when compiled with clang. So I’d say that’s already a pretty decent result, but can we do better? Sure we can, but let me share some observations first.

Fixing an overrun

Actually, this implementation has an issue. I mean, the reference C implementation explicitly assumes the input is trusted, but let’s assume ours isn’t. In that case, an attacker can corrupt memory by crafting an input where one of the last chunks is a Repeat long enough for the slice in VM.set (VM.unsafeSlice outPos cnt mvec) prevPixel to go out of bounds. Not good.

How can this be fixed? We could have used VM.slice instead of the unsafe version without overthinking. There, an incorrect slice causes a runtime exception. But this means that we’re paying the safety check tax on every Repeat, and I don’t like taxes (and, moreover, it won’t help us with some later implementations).

However, a better solution would be to allocate some additional memory for our resulting array. Note that any run can be at most 8224 pixels long — that is about 32 kilobytes for an RGBA picture. So let’s just replace VM.new n with VM.new $ n + 8224 and shrink it back to n when returning from the function:

@@ -72,9 +71,12 @@ peekChunk str pos prevPixel
   where
     byte = str ! pos
 
+maxRunLen :: Int
+maxRunLen = 8224
+
 decode3ch :: BS.ByteString -> Int -> V.Vector Pixel3
 decode3ch str n = V.create $ do
-  mvec <- VM.new n
+  mvec <- VM.new $ n + maxRunLen
 
   running <- VM.replicate 64 $ Pixel3 0 0 0
 
@@ -97,7 +99,7 @@ decode3ch str n = V.create $ do
         | otherwise = pure ()
   step 0 0 (Pixel3 0 0 0)
 
-  pure mvec
+  pure $ VM.take n mvec

This way, we don’t lose any performance at all, and even if an attacker places a run of 8224 repetitions as the last chunk, it will just fill that extra safety padding area, and the outPos < n check will stop the loop on the next iteration.
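Where does 8224 come from? The sketch below brute-forces the longest encodable run, assuming the run encoding of the QOI draft this post targets: a 1-byte tag 010xxxxx for runs of 1 to 32 pixels and a 2-byte tag 011xxxxx for runs of 33 to 8224 pixels (only the 1-byte case appears in the peekChunk excerpts here). runLength is a hypothetical helper, not part of the post’s code.

```haskell
import Data.Bits (shiftL, (.&.), (.|.))
import Data.Word (Word8)

-- Run-length tags as assumed for the QOI draft targeted by the post:
--   010xxxxx           1-byte run:  1 + x       (1 .. 32 pixels)
--   011xxxxx yyyyyyyy  2-byte run:  33 + (x:y)  (33 .. 8224 pixels)
-- Hypothetical helper: returns the run length, or Nothing for non-run tags.
runLength :: Word8 -> Word8 -> Maybe Int
runLength b1 b2
  | b1 .&. 0xe0 == 0x40 = Just (1 + fromIntegral (b1 .&. 0x1f))
  | b1 .&. 0xe0 == 0x60 = Just (33 + ((hi `shiftL` 8) .|. lo))
  | otherwise = Nothing
  where
    hi = fromIntegral (b1 .&. 0x1f) :: Int
    lo = fromIntegral b2 :: Int

-- Brute-force the longest encodable run over all run tags and second bytes.
maxRun :: Int
maxRun = maximum [n | b1 <- [0x40 .. 0x7f], b2 <- [0 .. 0xff], Just n <- [runLength b1 b2]]
```

Under these assumptions the largest run is 33 + 8191 = 8224 pixels, which is exactly the padding added after the output array.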

Generalizing to RGBA pixels

We actually have most of our code ready, and we just need to introduce a few abstractions that the compiler will hopefully elide.

Firstly, we’ll teach unboxed vectors to store our Pixel4 type:

derivingUnbox "Pixel4"
  [t| Pixel4 -> (Word8, Word8, Word8, Word8) |]
  [| \(Pixel4 r g b a) -> (r, g, b, a) |]
  [| \(r, g, b, a) -> Pixel4 r g b a |]

Then we’ll introduce a type class for a pixel with a few methods that we need throughout the code we’ve already written. Turns out that the class itself only needs conversions to an RGBA 4-tuple and from the four RGBA components:

class VM.Unbox a => Pixel a where
  toRGBA :: a -> (Word8, Word8, Word8, Word8)
  fromRGBA :: Word8 -> Word8 -> Word8 -> Word8 -> a

We can express the addition of an RGB triple or RGBA 4-tuple via this class:

addRGB :: Pixel pixel => pixel -> Word8 -> Word8 -> Word8 -> pixel
addRGB px dr dg db = addRGBA px dr dg db 0

addRGBA :: Pixel pixel => pixel -> Word8 -> Word8 -> Word8 -> Word8 -> pixel
addRGBA px dr dg db da = let (r, g, b, a) = toRGBA px
                          in fromRGBA (r + dr) (g + dg) (b + db) (a + da)

and we also add a helper initPixel which specifies the initial decoder (and, later, encoder) state:

initPixel :: Pixel pixel => pixel
initPixel = fromRGBA 0 0 0 255

The implementation of the class for Pixel4 is straightforward:

instance Pixel Pixel4 where
  toRGBA (Pixel4 r g b a) = (r, g, b, a)
  fromRGBA r g b a = Pixel4 r g b a

The implementation for Pixel3 is very similar, except that we ignore the input alpha and fix the output alpha at a constant 255:

instance Pixel Pixel3 where
  toRGBA (Pixel3 r g b) = (r, g, b, 255)
  fromRGBA r g b _ = Pixel3 r g b
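Here is a standalone copy of the class and both instances, as a sketch that drops the VM.Unbox superclass so it compiles without the vector package. It shows two things the decoder relies on: Word8 arithmetic wraps modulo 256, so a negative delta such as -2 can simply be passed as 254, and Pixel3 silently discards whatever alpha delta it is given.

```haskell
import Data.Word (Word8)

-- Same shapes as in the post, minus the VM.Unbox superclass.
data Pixel3 = Pixel3 Word8 Word8 Word8 deriving (Eq, Show)
data Pixel4 = Pixel4 Word8 Word8 Word8 Word8 deriving (Eq, Show)

class Pixel a where
  toRGBA :: a -> (Word8, Word8, Word8, Word8)
  fromRGBA :: Word8 -> Word8 -> Word8 -> Word8 -> a

instance Pixel Pixel3 where
  toRGBA (Pixel3 r g b) = (r, g, b, 255) -- alpha is always opaque
  fromRGBA r g b _ = Pixel3 r g b        -- incoming alpha is dropped

instance Pixel Pixel4 where
  toRGBA (Pixel4 r g b a) = (r, g, b, a)
  fromRGBA = Pixel4

-- Component-wise addition; Word8 (+) wraps around on overflow.
addRGBA :: Pixel p => p -> Word8 -> Word8 -> Word8 -> Word8 -> p
addRGBA px dr dg db da =
  let (r, g, b, a) = toRGBA px
   in fromRGBA (r + dr) (g + dg) (b + db) (a + da)

addRGB :: Pixel p => p -> Word8 -> Word8 -> Word8 -> p
addRGB px dr dg db = addRGBA px dr dg db 0

initPixel :: Pixel p => p
initPixel = fromRGBA 0 0 0 255
```

For instance, addRGB (Pixel3 250 0 0) 10 0 0 wraps around to Pixel3 4 0 0.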

Then, the type of peekChunk is generalized to an arbitrary Pixel pixel. All in all, the diff is pretty straightforward:
-peekChunk :: BS.ByteString -> Int -> Pixel3 -> (Int, ChunkResult Pixel3)
+peekChunk :: Pixel pixel => BS.ByteString -> Int -> pixel -> (Int, ChunkResult pixel)
 peekChunk str pos prevPixel
   | byte .>>. 6 == 0     = (1, Lookback $ fromIntegral $ byte .&. 0b00111111)
   | byte .>>. 5 == 0b010 = (1, Repeat $ fromIntegral $ 1 + byte .&. 0b00011111)
@@ -43,12 +43,12 @@ peekChunk str pos prevPixel
   | byte .>>. 6 == 0b10  = let dr = (byte .>>. 4 .&. 0b11) - 2
                                dg = (byte .>>. 2 .&. 0b11) - 2
                                db = (byte        .&. 0b11) - 2
-                            in (1, One $ prevPixel `addPixel3` Pixel3 dr dg db)
+                            in (1, One $ addRGB prevPixel dr dg db)
   | byte .>>. 5 == 0b110 = let next = str ! pos + 1
                                dr = (byte .&. 0b00011111) - 16
                                dg = (next .>>. 4)         - 8
                                db = (next .&. 0b00001111) - 8
-                            in (2, One $ prevPixel `addPixel3` Pixel3 dr dg db)
+                            in (2, One $ addRGB prevPixel dr dg db)
   | byte .>>. 4 == 0b1110 = let threeBytes :: Word32
                                 threeBytes = fromIntegral byte .<<. 16
                                          .|. fromIntegral (str ! pos + 1) .<<. 8
@@ -56,60 +56,54 @@ peekChunk str pos prevPixel
                                 dr = fromIntegral (threeBytes .>>. 15 .&. 0b11111) - 16
                                 dg = fromIntegral (threeBytes .>>. 10 .&. 0b11111) - 16
                                 db = fromIntegral (threeBytes .>>. 5  .&. 0b11111) - 16
-                                --da = fromIntegral (threeBytes         .&. 0b11111) - 16
-                             in (3, One $ prevPixel `addPixel3` Pixel3 dr dg db)
+                                da = fromIntegral (threeBytes         .&. 0b11111) - 16
+                             in (3, One $ addRGBA prevPixel dr dg db da)
   | byte .>>. 4 == 0b1111 = let hr = byte .>>. 3 .&. 0b1
                                 hg = byte .>>. 2 .&. 0b1
                                 hb = byte .>>. 1 .&. 0b1
                                 ha = byte        .&. 0b1
-                                Pixel3 r g b = prevPixel
+                                (r, g, b, a) = toRGBA prevPixel
                                 r' = (negate hr .&. (str ! pos + 1))
                                  .|. (hr - 1)   .&. r
                                 g' = (negate hg .&. (str ! pos + 1 + fromIntegral hr))
                                  .|. (hg - 1)   .&. g
                                 b' = (negate hb .&. (str ! pos + 1 + fromIntegral (hr + hg)))
                                  .|. (hb - 1)   .&. b
-                                --a = negate ha .&. (str ! pos + 1 + fromIntegral (hr + hg + hb))
-                             in (1 + fromIntegral (hr + hg + hb + ha), One $ Pixel3 r' g' b')
+                                a' = (negate ha .&. (str ! pos + 1 + fromIntegral (hr + hg + hb)))
+                                 .|. (ha - 1)   .&. a
+                             in (1 + fromIntegral (hr + hg + hb + ha), One $ fromRGBA r' g' b' a')
   | otherwise = error "unknown byte"
   where
     byte = str ! pos

decode3ch is treated similarly and also gets renamed to decodePixels. Here it’s better to show the end result:


decodePixels :: Pixel pixel => BS.ByteString -> Int -> V.Vector pixel
decodePixels str n = V.create $ do
  mvec <- VM.new $ n + maxRunLen

  running <- VM.replicate 64 initPixel

  let updateRunning px = let (r, g, b, a) = toRGBA px
                          in VM.unsafeWrite running (fromIntegral $ (r `xor` g `xor` b `xor` a) .&. 0b00111111) px

  let step inPos outPos prevPixel
        | outPos < n = do
            let (diff, chunk) = peekChunk str inPos prevPixel
            case chunk of
                 One px       -> do VM.unsafeWrite mvec outPos px
                                    updateRunning px
                                    step (inPos + diff) (outPos + 1)   px
                 Lookback pos -> do px <- VM.unsafeRead running pos
                                    VM.unsafeWrite mvec outPos px
                                    step (inPos + diff) (outPos + 1)   px
                 Repeat cnt   -> do VM.set (VM.unsafeSlice outPos cnt mvec) prevPixel
                                    step (inPos + diff) (outPos + cnt) prevPixel
        | otherwise = pure ()
  step 0 0 initPixel

  pure $ VM.take n mvec

How well does the code perform right now? 230 ms. No change at all. Indeed, the compiler is perfectly happy to inline and specialize everything since it knows the exact types at call sites, so we’re on the right path.

The remaining problem is that we’re still bound to decoding 3-channel images, even though our decoding functions are general enough. One last piece is updating decodeQoi to return either a V.Vector Pixel3 or a V.Vector Pixel4, depending on the input. Let’s do that, quick and dirty style:

data SomePixels
  = Pixels3 (V.Vector Pixel3)
  | Pixels4 (V.Vector Pixel4)

decodeQoi :: BS.ByteString -> Maybe (Header, SomePixels)
decodeQoi str = case decodeOrFail $ BSL.fromStrict $ BS.take 14 str of
                     Left _ -> Nothing
                     Right (_, _, h@Header { .. }) -> if hChannels == 3
                                                         then Just (h, Pixels3 $ decodePixels (BS.drop 14 str) (fromIntegral $ hWidth * hHeight))
                                                         else Just (h, Pixels4 $ decodePixels (BS.drop 14 str) (fromIntegral $ hWidth * hHeight))

How fast is that? 233 ms. Just a tad slower, yet entirely within measurement error.

Better error handling

There are a few things to fix about error handling.

Firstly, let’s get rid of that ugly error. Let’s add one more constructor to ChunkResult called Stop, and, well, just stop decoding when we encounter that:

@@ -31,6 +31,7 @@ data ChunkResult pixel
   = One pixel
   | Repeat Int
   | Lookback Int
+  | Stop
 
 peekChunk :: Pixel pixel => BS.ByteString -> Int -> pixel -> (Int, ChunkResult pixel)
 peekChunk str pos prevPixel
@@ -72,7 +73,7 @@ peekChunk str pos prevPixel
                                 a' = (negate ha .&. (str ! pos + 1 + fromIntegral (hr + hg + hb)))
                                  .|. (ha - 1)   .&. a
                              in (1 + fromIntegral (hr + hg + hb + ha), One $ fromRGBA r' g' b' a')
-  | otherwise = error "unknown byte"
+  | otherwise = (0, Stop)
   where
     byte = str ! pos
 
@@ -98,6 +99,7 @@ decodePixels str n = V.create $ do
                  Repeat cnt   -> do VM.set (VM.unsafeSlice outPos cnt mvec) prevPixel
                                     step (inPos + diff) (outPos + cnt) prevPixel
+                 Stop         -> pure ()
         | otherwise = pure ()
   step 0 0 initPixel

The effect on The Numbers? Again within measurement error (although it even seems to be a bit faster): 228 ms.

Then, handling an unsupported channel count is straightforward: we’ll just check it explicitly before decoding.

Finally, we need to check for overruns. Luckily, the format is well thought out, so this check is effectively free. In particular, note that we only ever read up to four bytes past the current byte (when decoding the full RGBA color), and that only happens when the current byte is non-zero. So, all we need to do is check that the file is indeed padded with zero bytes!
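A standalone sketch of that padding check, assuming four trailing zero bytes as in the draft format discussed here (isPadded and qoiPadding are hypothetical names; the post inlines the check in decodeQoi). Since none of the padding bytes is non-zero, any non-zero byte sits at least five bytes before the end, so reading four bytes past it stays inside the string.

```haskell
import qualified Data.ByteString as BS

-- Assumed number of zero bytes appended after the chunk stream.
qoiPadding :: Int
qoiPadding = 4

-- True if the input ends with at least qoiPadding zero bytes.
isPadded :: BS.ByteString -> Bool
isPadded str =
  BS.length str >= qoiPadding
    && BS.all (== 0) (BS.drop (BS.length str - qoiPadding) str)
```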

Doing all that and replacing the Maybe with a proper Either, we get a nice and beautiful decodeQoi:

data DecodeError
  = HeaderError String
  | UnsupportedChannels Int
  | UnpaddedFile
  deriving (Show)

decodeQoi :: BS.ByteString -> Either DecodeError (Header, SomePixels)
decodeQoi str = decodeWHeader str $ decodeOrFail $ BSL.fromStrict $ BS.take 14 str
  where
    decodeWHeader _ (Left (_, _, err)) = Left $ HeaderError err
    decodeWHeader str (Right (_, consumed, header))
      | any (\i -> (str ! BS.length str - i) /= 0) [1..4] = Left UnpaddedFile
      | hChannels header == 3 = Right (header, Pixels3 decode')
      | hChannels header == 4 = Right (header, Pixels4 decode')
      | otherwise = Left $ UnsupportedChannels $ fromIntegral $ hChannels header
      where
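        -- skip the header bytes consumed by decodeOrFail before decoding chunks
        decode' :: Pixel pixel => V.Vector pixel
        decode' = decodePixels (BS.drop (fromIntegral consumed) str) (fromIntegral $ hWidth header * hHeight header)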

Effect on numbers? None: 230 ms.

Exercise for the reader: there is still a condition that can arguably be considered a soft error. Namely, the chunk stream can be properly padded but contain fewer pixels than hWidth and hHeight tell us. Right now, the rest of the pixels are set to black. How would you change the code to account for this?

Basically, that’s it. It has pretty decent C-level performance, it doesn’t have too much low-level nonsense, and it’s reasonably readable. If I were a maintainer of an image formats library and I got a pull request with code like that, I’d be happy to merge it.

Minor tweaks

But if we want to be faster, other things are worth trying out.

Unsafe new

Firstly, we’re using VM.new, which initializes the underlying memory. What if we used VM.unsafeNew, which doesn’t? Well, it kinda improves things, but again within error: the running time becomes 227-228 ms.

Memory layout

Then, let’s take a closer look at how exactly we are storing pixels. Right now, we’re effectively storing them using the unboxed vector instance for triples. V.Vector (a, b, c), in turn, is effectively a triple of (V.Vector a, V.Vector b, V.Vector c), so it’s like we’ve got the Array of Structs → Struct of Arrays transformation for free. But is it really beneficial in our case? Let’s write the instances ourselves and see!

Sequentially storing Pixel4s is trivial (and we’re not benchmarking those anyway just yet): we just pack them as 32-bit words:

derivingUnbox "Pixel4"
  [t| Pixel4 -> Word32 |]
  [| \(Pixel4 r g b a) -> (fromIntegral r `shiftL` 24)
                      .|. (fromIntegral g `shiftL` 16)
                      .|. (fromIntegral b `shiftL` 8)
                      .|.  fromIntegral a
                      |]
  [| \w32 -> Pixel4 (fromIntegral $ w32 `shiftR` 24)
                    (fromIntegral $ w32 `shiftR` 16)
                    (fromIntegral $ w32 `shiftR` 8)
                    (fromIntegral   w32)
                    |]
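As a quick sanity check of this packing, here is the same conversion written as plain functions (packPixel4 and unpackPixel4 are hypothetical names; the post performs it inside the derivingUnbox splice). Converting a Word32 to Word8 with fromIntegral keeps only the low 8 bits, which is why the unpacking side needs no masks.

```haskell
import Data.Bits (shiftL, shiftR, (.|.))
import Data.Word (Word32, Word8)

data Pixel4 = Pixel4 Word8 Word8 Word8 Word8 deriving (Eq, Show)

-- Red in the top byte, alpha in the bottom one, as in the splice above.
packPixel4 :: Pixel4 -> Word32
packPixel4 (Pixel4 r g b a) =
  (fromIntegral r `shiftL` 24) .|. (fromIntegral g `shiftL` 16)
    .|. (fromIntegral b `shiftL` 8) .|. fromIntegral a

-- fromIntegral truncates to the low 8 bits of each shifted word.
unpackPixel4 :: Word32 -> Pixel4
unpackPixel4 w =
  Pixel4 (fromIntegral (w `shiftR` 24)) (fromIntegral (w `shiftR` 16))
         (fromIntegral (w `shiftR` 8)) (fromIntegral w)
```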

Storing Pixel3s is more interesting, as we now have to do all the heavy lifting ourselves. So, we remove derivingUnbox "Pixel3" and write explicit instances:

newtype instance VM.MVector s Pixel3 = MV_Pixel3 { getMVP3 :: VM.MVector s Word8 }
newtype instance V.Vector     Pixel3 = V_Pixel3  { getVP3  :: V.Vector Word8 }

instance VG.Vector V.Vector Pixel3 where
  basicUnsafeFreeze = fmap V_Pixel3 . VG.basicUnsafeFreeze . getMVP3
  basicUnsafeThaw = fmap MV_Pixel3 . VG.basicUnsafeThaw . getVP3
  basicLength = (`div` 3) . VG.basicLength . getVP3
  basicUnsafeSlice s l = V_Pixel3 . VG.basicUnsafeSlice (s * 3) (l * 3) . getVP3
  basicUnsafeIndexM (V_Pixel3 vec) idx = Pixel3 <$> VG.basicUnsafeIndexM vec idx'
                                                <*> VG.basicUnsafeIndexM vec (idx' + 1)
                                                <*> VG.basicUnsafeIndexM vec (idx' + 2)
    where
      idx' = idx * 3
  elemseq _ !px b = b

instance VMG.MVector VM.MVector Pixel3 where
  basicLength = (`div` 3) . VMG.basicLength . getMVP3
  basicUnsafeSlice s l = MV_Pixel3 . VMG.basicUnsafeSlice (s * 3) (l * 3) . getMVP3
  basicOverlaps (MV_Pixel3 v1) (MV_Pixel3 v2) = VMG.basicOverlaps v1 v2
  basicUnsafeNew = fmap MV_Pixel3 . VMG.basicUnsafeNew . (* 3)
  basicInitialize = VMG.basicInitialize . getMVP3
  basicUnsafeRead (MV_Pixel3 vec) idx = Pixel3 <$> VMG.basicUnsafeRead vec idx'
                                               <*> VMG.basicUnsafeRead vec (idx' + 1)
                                               <*> VMG.basicUnsafeRead vec (idx' + 2)
    where
      idx' = idx * 3
  basicUnsafeWrite (MV_Pixel3 vec) idx (Pixel3 r g b) = VMG.basicUnsafeWrite vec idx' r
                                                     >> VMG.basicUnsafeWrite vec (idx' + 1) g
                                                     >> VMG.basicUnsafeWrite vec (idx' + 2) b
    where
      idx' = idx * 3

instance VM.Unbox Pixel3

It has a statistically significant, albeit small, effect on the run time: 224 ms. Perhaps not worth the extra complexity and room for error.

Bytestring ultra-hardcore

There’s this little problem with ByteString that forces me to roll up my sleeves and do some ugly work almost every time I’m trying to write high-performance code in Haskell.

You see, ByteString supports the notion of slices. Slices allow functions like BS.take or BS.splitAt to be O(1) time and memory: they don’t copy the source string. That’s undoubtedly a good thing.

But how are those slices (thus, ByteStrings in general) represented? Well, frankly, quite suboptimally: each string is, among other fields, a pointer p to the start of the whole string and an offset s to the beginning of this slice. Hence, even if the string is something you’ve just created (like you’ve read from a file or BS.packed a bunch of bytes), you’ll still have this s field, even if it’s 0.

What does this mean performance-wise? Well, this means that to access the i-th byte, you need to compute p + s + i: two additions instead of the one you’d generally expect. And even if you’ve never taken any slices, the compiler isn’t smart enough to figure out that s is always 0 and elide one of the additions.

Luckily, this has finally been fixed in bytestring-0.11, where a ByteString is now a pointer to the beginning of the slice without any extra offsets, as it should be.

Unfortunately, all the Stackage snapshots, including nightlies, are still using bytestring-0.10, so we’ll need to backport things. For our quick and dirty implementation, it’s a matter of adding import qualified Data.ByteString.Internal as BSI and replacing the definition of our (!) with an expression calling quite a funnily named function:

(BSI.PS x _ _) ! i = BSI.accursedUnutterablePerformIO $ withForeignPtr x $ \p -> peekByteOff p i

The effect? Quite noticeable: 208 ms.

Sure, we now need to adjust the pointer in the calling code to ensure that the offset is, in fact, zero, but we only have to do this once, as opposed to an extra addition on every iteration. Doing this is quite trivial and is a matter of using the following helper function:

unoffsetBS :: BS.ByteString -> BS.ByteString
unoffsetBS (BSI.PS ptr offset len) = BSI.PS (ptr `plusForeignPtr` offset) 0 len
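If you’d rather not match on the PS constructor directly, a variant along these lines should work too. It is a sketch assuming Data.ByteString.Internal exports toForeignPtr and fromForeignPtr (I believe both 0.10 and 0.11 do), and unoffsetBS' is a hypothetical name. On bytestring-0.11 the offset returned by toForeignPtr is always 0, so there the helper becomes a no-op.

```haskell
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import GHC.ForeignPtr (plusForeignPtr)

-- Fold the slice offset into the pointer so that later indexing
-- only needs a single addition (pointer + index).
unoffsetBS' :: BS.ByteString -> BS.ByteString
unoffsetBS' str =
  let (fptr, offset, len) = BSI.toForeignPtr str
   in BSI.fromForeignPtr (fptr `plusForeignPtr` offset) 0 len
```

The contents are unchanged: unoffsetBS' (BS.drop 2 (BS.pack [1, 2, 3, 4, 5])) still equals BS.pack [3, 4, 5], only with a zero offset.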

I also can’t help but notice that this saves us about 53 million extra additions (the sample file is about 53 megabytes, and we read each byte once), which, on my 4 GHz machine, should take approximately 53 × 1048576 / (4 × 10⁹) s ≈ 14 ms, assuming one cycle per instruction. Indeed, 224 - 208 = 16 ms is close to our ballpark estimate of an extra 14 ms. It’s nice to see for once that expectations match reality this well, given how complicated modern CPUs are.

Array ultra-hardcore

Interestingly, the vector family of types has the same peculiarity as ByteString. Sure, we could’ve gone a similar route and implemented a restricted sliceless indexing operator ourselves. But, unfortunately, it’s not entirely trivial: quite a few operations are defined in terms of the *.Generic interfaces, and I really don’t want to go and untangle all that.

Instead, we can go one level down and use the array library rather than vector. It doesn’t support slices in the first place, and, although it supports non-zero-indexed arrays, unsafe indexing throws that out of the window, so, long story short, arrays are the way to go.

First, we’ll teach our unboxed arrays how to store Pixel3s. That is, we need to write instances of MArray (STUArray s) Pixel3 (ST s) and IArray UArray Pixel3. This time, it involves a bunch of ugly-looking #-rich low-level primitive code:

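{-# LANGUAGE FlexibleInstances, MagicHash, MultiParamTypeClasses, UnboxedTuples #-}

import qualified Data.Array.Base as A
import qualified Data.Array.MArray as A
import qualified Data.Array.ST as A
-- primops and raw constructors used below (on GHC 8.10, W8# wraps a Word#)
import GHC.Exts (Int (I#), indexWord8Array#, readWord8Array#, writeWord8Array#, (*#), (+#))
import GHC.ST (ST (..), runST)
import GHC.Word (Word8 (W8#))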

instance A.MArray (A.STUArray s) Pixel3 (ST s) where
  getBounds (A.STUArray l u _ _) = pure (l, u)
  {-# INLINE getBounds #-}
  getNumElements (A.STUArray _ _ n _) = pure n
  {-# INLINE getNumElements #-}

  newArray_ arrBounds = A.newArray arrBounds (Pixel3 0 0 0)
  {-# INLINE newArray_ #-}
  unsafeNewArray_ (l, u) = A.unsafeNewArraySTUArray_ (l, u) (*# 3#)
  {-# INLINE unsafeNewArray_ #-}

  unsafeRead (A.STUArray _ _ _ marr#) (I# n#) = ST $ \s1# ->
    let n'# = n# *# 3#
        (# s2#, r# #) = readWord8Array# marr# n'#         s1#
        (# s3#, g# #) = readWord8Array# marr# (n'# +# 1#) s2#
        (# s4#, b# #) = readWord8Array# marr# (n'# +# 2#) s3#
     in (# s4#, Pixel3 (W8# r#) (W8# g#) (W8# b#) #)
  {-# INLINE unsafeRead #-}
  unsafeWrite (A.STUArray _ _ _ marr#) (I# n#) (Pixel3 (W8# r#) (W8# g#) (W8# b#)) = ST $ \s1# ->
    let n'# = n# *# 3#
        s2# = writeWord8Array# marr# n'#         r# s1#
        s3# = writeWord8Array# marr# (n'# +# 1#) g# s2#
        s4# = writeWord8Array# marr# (n'# +# 2#) b# s3#
     in (# s4#, () #)
  {-# INLINE unsafeWrite #-}

instance A.IArray A.UArray Pixel3 where
  bounds (A.UArray l u _ _) = (l, u)
  numElements (A.UArray  _ _ n _) = n
  unsafeArray lu ies = runST (A.unsafeArrayUArray lu ies $ Pixel3 0 0 0)
  unsafeAt (A.UArray _ _ _ arr#) (I# n#) = Pixel3 (W8# (indexWord8Array# arr# n'#))
                                                  (W8# (indexWord8Array# arr# (n'# +# 1#)))
                                                  (W8# (indexWord8Array# arr# (n'# +# 2#)))
    where
      n'# = n# *# 3#

Similarly for Pixel4s:

data Pixel4 = Pixel4 Word8 Word8 Word8 Word8 deriving (Show)

instance A.MArray (A.STUArray s) Pixel4 (ST s) where
  getBounds (A.STUArray l u _ _) = pure (l, u)
  {-# INLINE getBounds #-}
  getNumElements (A.STUArray _ _ n _) = pure n
  {-# INLINE getNumElements #-}

  newArray_ arrBounds = A.newArray arrBounds (Pixel4 0 0 0 0)
  {-# INLINE newArray_ #-}
  unsafeNewArray_ (l, u) = A.unsafeNewArraySTUArray_ (l, u) (*# 4#)
  {-# INLINE unsafeNewArray_ #-}

  unsafeRead (A.STUArray _ _ _ marr#) (I# n#) = ST $ \s1# ->
    let n'# = n# *# 4#
        (# s2#, r# #) = readWord8Array# marr# n'#         s1#
        (# s3#, g# #) = readWord8Array# marr# (n'# +# 1#) s2#
        (# s4#, b# #) = readWord8Array# marr# (n'# +# 2#) s3#
        (# s5#, a# #) = readWord8Array# marr# (n'# +# 3#) s4#
     in (# s5#, Pixel4 (W8# r#) (W8# g#) (W8# b#) (W8# a#) #)
  {-# INLINE unsafeRead #-}
  unsafeWrite (A.STUArray _ _ _ marr#) (I# n#) (Pixel4 (W8# r#) (W8# g#) (W8# b#) (W8# a#)) = ST $ \s1# ->
    let n'# = n# *# 4#
        s2# = writeWord8Array# marr# n'#         r# s1#
        s3# = writeWord8Array# marr# (n'# +# 1#) g# s2#
        s4# = writeWord8Array# marr# (n'# +# 2#) b# s3#
        s5# = writeWord8Array# marr# (n'# +# 3#) a# s4#
     in (# s5#, () #)
  {-# INLINE unsafeWrite #-}

instance A.IArray A.UArray Pixel4 where
  bounds (A.UArray l u _ _) = (l, u)
  numElements (A.UArray  _ _ n _) = n
  unsafeArray lu ies = runST (A.unsafeArrayUArray lu ies $ Pixel4 0 0 0 0)
  unsafeAt (A.UArray _ _ _ arr#) (I# n#) = Pixel4 (W8# (indexWord8Array# arr# n'#))
                                                  (W8# (indexWord8Array# arr# (n'# +# 1#)))
                                                  (W8# (indexWord8Array# arr# (n'# +# 2#)))
                                                  (W8# (indexWord8Array# arr# (n'# +# 3#)))
    where
      n'# = n# *# 4#

This is actually the worst part. With these instances in place, updating the decoder is straightforward: the pure peekChunk is not affected at all, and the rest of the changes are primarily cosmetic:

-decodePixels :: Pixel pixel => BS.ByteString -> Int -> Int -> V.Vector pixel
-decodePixels str strFrom n = V.create $ do
-  mvec <- VM.new $ n + maxRunLen
+decodePixels :: Pixel pixel => BS.ByteString -> Int -> Int -> A.UArray Int pixel
+decodePixels str strFrom n = A.runSTUArray $ do
+  (mvec :: A.STUArray s Int pixel) <- A.unsafeNewArray_ (0, n + maxRunLen - 1)
 
-  running <- VM.replicate 64 initPixel
+  running <- A.newArray @(A.STUArray s) (0, 63 :: Int) initPixel
 
   let updateRunning px = let (r, g, b, a) = toRGBA px
-                          in VM.unsafeWrite running (fromIntegral $ (r `xor` g `xor` b `xor` a) .&. 0b00111111) px
+                          in A.unsafeWrite running (fromIntegral $ (r `xor` g `xor` b `xor` a) .&. 0b00111111) px
 
   let step inPos outPos prevPixel
         | outPos < n = do
             let (diff, chunk) = peekChunk str inPos prevPixel
             case chunk of
-                 One px       -> do VM.unsafeWrite mvec outPos px
+                 One px       -> do A.unsafeWrite mvec outPos px
                                     updateRunning px
                                     step (inPos + diff) (outPos + 1)   px
-                 Lookback pos -> do px <- VM.unsafeRead running pos
-                                    VM.unsafeWrite mvec outPos px
+                 Lookback pos -> do px <- A.unsafeRead running pos
+                                    A.unsafeWrite mvec outPos px
                                     step (inPos + diff) (outPos + 1)   px
-                 Repeat cnt   -> do VM.set (VM.unsafeSlice outPos cnt mvec) prevPixel
+                 Repeat cnt   -> do forM_ [0..cnt - 1] $ \i -> A.unsafeWrite mvec (outPos + i) prevPixel
                                     step (inPos + diff) (outPos + cnt) prevPixel
                  Stop         -> pure ()
         | otherwise = pure ()
 
-  pure $ VM.take n mvec
+  pure $ unsafeShrink mvec n
 
 data SomePixels
-  = Pixels3 (V.Vector Pixel3)
-  | Pixels4 (V.Vector Pixel4)
+  = Pixels3 (A.UArray Int Pixel3)
+  | Pixels4 (A.UArray Int Pixel4)
 
 decodeQoi :: BS.ByteString -> Maybe (Header, SomePixels)
 decodeQoi str
@@ -126,5 +131,5 @@ decodeQoi str
                                    , hHeight = fromBigEndian hHeight
                                    , ..
                                    }
-    decode' :: Pixel pixel => V.Vector pixel
-    decode' = decodePixels str consumed (fromIntegral $ hWidth header * hHeight header)
+    decode' :: Pixel pixel => A.UArray Int pixel
+    decode' = decodePixels str consumed (fromIntegral $ hWidth header * hHeight header)

Here, unsafeShrink basically adjusts the length of the array without copying. We implement this function ourselves:

unsafeShrink :: A.STUArray s Int e -> Int -> A.STUArray s Int e
unsafeShrink arr@(A.STUArray l _ n marr) cnt
  | cnt >= n = arr
  | otherwise = A.STUArray l (l + cnt - 1) cnt marr

Not sure what’s unsafe about it: it even does the proper check, after all. I probably should have just named it shrink.
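A small usage sketch of unsafeShrink, with the function reproduced so the snippet compiles on its own. firstSquares is a hypothetical example: it over-allocates a buffer, fills it, and exposes only a prefix, just like decodePixels does with its n + maxRunLen buffer. It relies on the STUArray constructor exported by Data.Array.Base, as the instances above do.

```haskell
import Control.Monad (forM_)
import Data.Array.Base (STUArray (..))
import Data.Array.ST (newArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, bounds, elems)
import Data.Word (Word8)

-- Rewrites only the upper bound and the element count; the buffer is shared.
unsafeShrink :: STUArray s Int e -> Int -> STUArray s Int e
unsafeShrink arr@(STUArray l _ n marr) cnt
  | cnt >= n = arr
  | otherwise = STUArray l (l + cnt - 1) cnt marr

-- Hypothetical demo: fill a 16-element buffer, then keep only the first k.
firstSquares :: Int -> UArray Int Word8
firstSquares k = runSTUArray $ do
  buf <- newArray (0, 15) 0 -- over-allocate, like n + maxRunLen
  forM_ [0 .. 15] $ \i -> writeArray buf i (fromIntegral (i * i))
  pure (unsafeShrink buf k)
```

Here firstSquares 4 has bounds (0, 3) and elements [0, 1, 4, 9], while asking for more than 16 elements returns the buffer unchanged.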

Anyway, this one runs in 190 ms. So, once again, we’ve shaved off slightly more than 16 ms, as expected.

Is it worth it? Well, it’s almost 10% of run time, and all the horror hides in the instances for unboxed arrays for Pixels, separate from the decoding/encoding logic, nicely testable, and so on. Hence I think it’s definitely worth it.

Alright, that’s already a total win. But let’s give it a couple more finishing touches.

Branchless bit-shuffling

There are a few funny lines above that made me feel very proud of myself when I wrote them.

In the part of the decoder that handles the full encoded color, for each color component, we need to check a specific bit (3rd for red, 2nd for green, etc.), and, if it’s set, read the corresponding byte, otherwise use the previous value. So here’s how it happens in your code right now, dear reader, if you’re following me:

peekChunk str pos prevPixel
  | ...
  | byte .>>. 4 == 0b1111 = let hr = byte .>>. 3 .&. 0b1
                                hg = byte .>>. 2 .&. 0b1
                                hb = byte .>>. 1 .&. 0b1
                                ha = byte        .&. 0b1
                                (r, g, b, a) = toRGBA prevPixel
                                r' = (negate hr .&. (str ! pos + 1))
                                 .|. (hr - 1)   .&. r
                                g' = (negate hg .&. (str ! pos + 1 + fromIntegral hr))
                                 .|. (hg - 1)   .&. g
                                b' = (negate hb .&. (str ! pos + 1 + fromIntegral (hr + hg)))
                                 .|. (hb - 1)   .&. b
                                a' = (negate ha .&. (str ! pos + 1 + fromIntegral (hr + hg + hb)))
                                 .|. (ha - 1)   .&. a
                             in (1 + fromIntegral (hr + hg + hb + ha), One $ fromRGBA r' g' b' a')

Let’s focus on the red channel. hr denotes whether the bit for the red color is set. Now, if it’s set, we need to use the next byte (str ! pos + 1), otherwise we need to use r.

I was cautious to avoid branching here because, as we all know, modern CPUs hate branching. This is why I wrote this abomination:

r' = (negate hr .&. (str ! pos + 1))
 .|. (hr - 1)   .&. r

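What the hell is this? Well, let’s analyze the cases:

  1. If hr is 1, then negate hr is 0xFF (all bits set) and hr - 1 is 0, so the left operand of .|. is the next byte, the right one is 0, and r' is str ! pos + 1.
  2. If hr is 0, then negate hr is 0 and hr - 1 wraps around to 0xFF, so the left operand is 0, the right one is r, and r' is r.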

So, we have no branches, and we get the desired result. Sure, we have to do slightly more operations, but they are just basic arithmetic on single bytes; thus, they are super cheap. Furthermore, the overall intuition is that any decent CPU will compute the two sides of .|. in parallel. Hence we’ll add at most, ugh, a negate and a few integer additions to our dependency chain, and we won’t really need our new pixel until we parse out the next chunk from the input stream, which won’t happen “soon” cycles-wise.
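To convince ourselves that the mask trick matches the obvious branch, here is a standalone sketch comparing both forms (selectMasked and selectBranchy are hypothetical names; the post inlines the expressions per channel). The mask version is only valid when flag is 0 or 1, which holds for hr, hg, hb, and ha since each is masked with .&. 0b1.

```haskell
import Data.Bits ((.&.), (.|.))
import Data.Word (Word8)

-- Branchless select: flag == 1 picks new, flag == 0 keeps old.
-- Relies on Word8 wrap-around: negate 1 == 0xff and 0 - 1 == 0xff.
selectMasked :: Word8 -> Word8 -> Word8 -> Word8
selectMasked flag new old = (negate flag .&. new) .|. ((flag - 1) .&. old)

-- The plain branching version.
selectBranchy :: Word8 -> Word8 -> Word8 -> Word8
selectBranchy flag new old = if flag == 1 then new else old
```

Both agree on every input with flag in {0, 1}; which one is faster is a question for the measurements that follow.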

Nice. Now, I’m just curious, how much does this bit-hackery give us? Let’s replace this with stupid branches and measure!

                                 hb = byte .>>. 1 .&. 0b1
                                 ha = byte        .&. 0b1
                                 (r, g, b, a) = toRGBA prevPixel
-                                r' = (negate hr .&. (str ! pos + 1))
-                                 .|. (hr - 1)   .&. r
-                                g' = (negate hg .&. (str ! pos + 1 + fromIntegral hr))
-                                 .|. (hg - 1)   .&. g
-                                b' = (negate hb .&. (str ! pos + 1 + fromIntegral (hr + hg)))
-                                 .|. (hb - 1)   .&. b
-                                a' = (negate ha .&. (str ! pos + 1 + fromIntegral (hr + hg + hb)))
-                                 .|. (ha - 1)   .&. a
+                                r' = if hr == 1 then str ! pos + 1 else r
+                                g' = if hg == 1 then str ! pos + 1 + fromIntegral hr else g
+                                b' = if hb == 1 then str ! pos + 1 + fromIntegral (hr + hg) else b
+                                a' = if ha == 1 then str ! pos + 1 + fromIntegral (hr + hg + hb) else a
                              in (1 + fromIntegral (hr + hg + hb + ha), One $ fromRGBA r' g' b' a')
   | otherwise = (0, Stop)
   where

Turns out, the difference is enormous! But… but it’s in the other direction. The code with branching is faster. It’s faster even for the 3-channel RGB data, where the compiler elides X, which has the longest dependency chain (due to the highest number of additions).

How much faster? This stuff runs in 172 ms instead of the 190 ms we previously had. That’s quite a difference, indeed!

Now, what if we apply the same treatment to our Vector-using version, without all the low-level stuff with Arrays or ByteString guts? 191 ms. Even somewhat faster than the best C results.

Conclusion and next steps

In this post, we’ve started with a reasonably high-level and quick-and-dirty implementation of a QOI decoder, which turned out to be on par with the official C version.

With a few changes, some of which are universally applicable (like using Arrays instead of Vector), we managed to further improve the performance of our Haskell version to be around 10-20% faster than the best C times in most cases.

In the next post:

Stay tuned!