Writing a fast edit distance implementation

January 1, 2020

In this post we will implement the Levenshtein edit distance algorithm, aiming for reasonably performant code. We will start with a more or less idiomatic pure implementation and see how various changes (including strictness annotations and compilation options) affect the performance.

We will also compare this to a baseline C++ implementation.

Spoiler alerts:

Long story short, the Levenshtein algorithm takes two strings and computes how many characters have to be inserted, deleted or changed in order to make those strings match.

We will follow what’s known as the dynamic programming approach with two matrix rows, having O(mn) time complexity and O(n) space complexity, where m and n are the lengths of the respective input strings (passing the shorter string second gives O(min(m, n)) space).
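To make the recurrence concrete, here is a small list-based sketch of the same two-row scheme. It is meant only as a readable correctness reference (levenshteinRef is a hypothetical name, and this is not one of the implementations benchmarked below): each step turns the row of distances from a prefix of s1 to every prefix of s2 into the next row.

```haskell
import Data.List (foldl')

-- | Reference edit distance using two-row dynamic programming over lists.
-- 'prev' holds the distances from the first (i - 1) characters of s1 to
-- every prefix of s2; 'step' builds the row for the first i characters.
levenshteinRef :: String -> String -> Int
levenshteinRef s1 s2 = last (foldl' step [0 .. length s2] (zip [1 ..] s1))
  where
    step prev (i, c1) = scanl cell i (zip3 s2 prev (drop 1 prev))
      where
        -- left: new row, previous column (insertion)
        -- diag: previous row, previous column (substitution or match)
        -- up:   previous row, same column (deletion)
        cell left (c2, diag, up) =
          minimum [diag + fromEnum (c1 /= c2), up + 1, left + 1]
```

For example, levenshteinRef "kitten" "sitting" evaluates to 3 (two substitutions and one insertion).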

Motivation

There are already a few implementations of the edit distance, namely the one in the text-metrics package as well as the ultra-fast edit-distance package.

There were some reasons for me to write yet another implementation besides the obvious one:

  1. text-metrics works with Data.Text, so presumably there is some performance penalty related to dealing with Unicode. In short, I didn’t care about non-ASCII text, so I didn’t need that penalty.
  2. edit-distance uses a square STUArray (a full-matrix approach), thus having a different space complexity. I’m sort of ashamed to admit that the previous sentence is actually incorrect, and it’s just me reading the code wrongly — in fact the underlying algorithm there is way, way more complex.
  3. And, most importantly, writing code is fun, trying to make code fast is twice as fun, learning from that experience is infinitely fun!

Anyway, here we go.

Benchmarking

We will have two ways of benchmarking this. The first one is with the criterion package that we all know and love. But running criterion is slow, so we’ll also have a really stupid benchmark that generates three strings of 20000 characters each (s1 and s2 made up entirely of 'a', s3 entirely of 'b', just like in the C++ baseline below) and then computes the distance from s1 to s2 (which is obviously 0) and the distance from s1 to s3 (which is 20000).

Since preparing the benchmark data takes negligible time compared to the computation itself, we will just measure how long the whole program runs (via +RTS -sstderr) and take the minimum of three measurements: the OS and the rest of the processes on my machine can make the program run slower, but they can never make it run faster than under ideal circumstances. Moreover, since this test is fully deterministic, there is no need to average over different inputs.
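A minimal sketch of such a driver, assuming the inputs described above (two identical strings of 20000 'a' characters and one of 20000 'b' characters, as in the C++ baseline below). The names sloppyInputs and runSloppyBench are hypothetical; the distance function is passed in so that any of the implementations below can be plugged in, and the timing itself comes from the RTS statistics rather than from this code.

```haskell
import qualified Data.ByteString.Char8 as BS

-- | The three inputs of the sloppy benchmark: s1 and s2 are identical,
-- s3 differs from s1 in every position.
sloppyInputs :: (BS.ByteString, BS.ByteString, BS.ByteString)
sloppyInputs =
  (BS.replicate 20000 'a', BS.replicate 20000 'a', BS.replicate 20000 'b')

-- | Runs both comparisons with the given distance function and prints the
-- results, which should be 0 and 20000 for a correct implementation.
runSloppyBench :: (BS.ByteString -> BS.ByteString -> Int) -> IO ()
runSloppyBench dist = do
  let (s1, s2, s3) = sloppyInputs
  print (dist s1 s2)
  print (dist s1 s3)
```

With one of the implementations below in scope, the benchmark program would then be just main = runSloppyBench levenshteinDistance.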

So we’ll go along citing the results of this stupid benchmark, and then at the end we’ll present the overall graphs of all the implementations as benchmarked via criterion.

Baseline

First, let’s set our expectations: how fast can we possibly be?

There are two ways to estimate this: theoretical and practical.

Theoretically

Let’s assume that all of our data fits in L2 but gets evicted from L1. Is this justified? Let’s look at what data we have to keep around: the string s2 (one byte per character) and the two matrix rows of 64-bit integers.

How much data do we have to pass to/from L2?

We do roughly 20000 iterations, each of which reads the whole string (20 KB) and the whole array built on the previous step (160 KB), while writing back the new version of the array (also 160 KB). Focusing just on the reads, we need to read roughly (20 + 160) × 20000 KB from L2 in total, which is 3.6 GB. The Haswell CPU that I’m using has an L2 read bandwidth of around 50 GB/s, which gives something on the order of 0.1 s as an absolute minimum run time per comparison.

Since we’re doing two comparisons, the absolute minimum would be about 0.2 s.
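Spelled out, assuming one byte per character, 64-bit row entries and decimal units (1 KB = 1000 bytes), the estimate is:

```latex
\begin{aligned}
\text{row size} &= 20001 \times 8\,\text{B} \approx 160\,\text{KB},
  \qquad \text{string size} = 20000 \times 1\,\text{B} = 20\,\text{KB} \\
\text{reads per comparison} &\approx 20000 \times (20 + 160)\,\text{KB}
  = 3.6 \times 10^{6}\,\text{KB} = 3.6\,\text{GB} \\
t_{\text{comparison}} &\gtrsim \frac{3.6\,\text{GB}}{50\,\text{GB/s}} = 0.072\,\text{s} \\
t_{\text{total}} &\gtrsim 2 \times 0.072\,\text{s} \approx 0.14\,\text{s}
\end{aligned}
```

So the "about 0.2 s" figure is an order-of-magnitude bound rather than a precise one.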

This makes a ton of assumptions, like infinite CPU performance (so we’re bottlenecked on memory) or lack of interference between simultaneous reads and writes.
Practically

Let’s throw together this little C++ program as a practical baseline:

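#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

size_t lev_dist(const std::string& s1, const std::string& s2)
{
  const auto m = s1.size();
  const auto n = s2.size();

  // v0 is the previous row of the DP matrix, v1 is the row being computed.
  std::vector<int64_t> v0;
  v0.resize(n + 1);
  std::iota(v0.begin(), v0.end(), 0);

  auto v1 = v0;

  for (size_t i = 0; i < m; ++i)
  {
    v1[0] = i + 1;

    auto s1char = s1[i];
    for (size_t j = 0; j < n; ++j)
    {
      auto delCost = v0[j + 1] + 1;
      auto insCost = v1[j] + 1;
      int substCost = s1char != s2[j];

      v1[j + 1] = std::min({ v0[j] + substCost, delCost, insCost });
    }

    // The row just computed becomes the previous row of the next iteration.
    std::swap(v0, v1);
  }

  return v0[n];
}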

int main()
{
    std::string s1(20000, 'a');
    std::string s2(20000, 'a');
    std::string s3(20000, 'b');

    std::cout << lev_dist(s1, s2) << std::endl;
    std::cout << lev_dist(s1, s3) << std::endl;
}

Its run time is around 1.3 seconds.

More precisely, the minimum run time across three runs is 1.356 seconds (Core i7 4770, no overclocking). For reproducibility, the C++ compiler used is gcc 9.2, compiling with -O3 -march=native.

Note that this function uses initializer-list-based std::min, but we’ll get to that (and some other C++-related things) later.

So we’re going to use 1.3 seconds as our goal.

The code

Pure implementation

No explicit mutability besides what’s in Data.Vector, and it’s debatable whether we should even count that. No explicit tail recursion. Just a strict left fold and Data.Vector.constructN:

import qualified Data.ByteString as BS
import qualified Data.Vector.Unboxed as V
import Data.List

levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = foldl' outer (V.generate (n + 1) id) [0 .. m - 1] V.! n
  where
    m = BS.length s1
    n = BS.length s2

    outer v0 i = V.constructN (n + 1) ctr
      where
        s1char = s1 `BS.index` i
        ctr v1 | V.length v1 == 0 = i + 1
        ctr v1 = min (substCost + substCostBase) $ 1 + min delCost insCost
          where
            j = V.length v1
            delCost = v0 V.! j
            insCost = v1 V.! (j - 1)
            substCostBase = v0 V.! (j - 1)
            substCost = if s1char == s2 `BS.index` (j - 1) then 0 else 1

That might seem a bit ugly, but that’s still reasonably nice and idiomatic (as much as such algorithms can be expressed idiomatically at all).
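If V.constructN is unfamiliar: it builds a vector of the requested length, and each call to the generator sees the prefix constructed so far. That is why ctr can use j = V.length v1 as the current column and read v1 at j - 1 for the insertion cost. A list model of that contract (constructNList is a hypothetical helper for illustration; it is quadratic because of ++, whereas the real function fills an unboxed vector in place):

```haskell
-- | List model of Data.Vector's constructN: the element at position k is
-- computed by applying the generator to the k elements built so far.
constructNList :: Int -> ([a] -> a) -> [a]
constructNList n f = go [] 0
  where
    go acc k
      | k >= n = acc
      | otherwise = go (acc ++ [f acc]) (k + 1)
```

For instance, constructNList 5 length yields [0, 1, 2, 3, 4], since each element is the length of the prefix that precedes it.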

The run time with the sloppy benchmark is about 5.5 seconds. Ugh. Not nice.

Can we do better?

Cheap improvements

There are two ways to improve the run time without even touching the code of the function.

The first one is to add strictness annotations, since the compiler is unlikely to have figured them all out. We’re taking the lazy path and just adding {-# LANGUAGE Strict #-} to the top of the file.

The result? 3.4 seconds. Better, but not that much.
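Roughly speaking, Strict makes function arguments and let/where bindings in the module strict, as if each carried a bang pattern (a lazy pattern written with ~ opts back out), and it also implies StrictData. For a single accumulator loop, the hand-written equivalent of what the pragma does looks like the sketch below; sumTo is a hypothetical example, not code from this post.

```haskell
{-# LANGUAGE BangPatterns #-}

-- | Sum of 1..n with explicitly strict loop arguments. Under
-- {-# LANGUAGE Strict #-} these bangs would be implied; without them, and
-- without the optimiser noticing, acc could build a chain of (+) thunks.
sumTo :: Int -> Int
sumTo n = go 0 1
  where
    go !acc !i
      | i > n = acc
      | otherwise = go (acc + i) (i + 1)
```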

But what if we compile our code via LLVM in addition to the Strict pragma? Let’s add {-# OPTIONS_GHC -fllvm #-} to the top of the file!

The result? 2.1 seconds! That’s a huge performance improvement for a mere code generator change, and it starts getting close to the C++ version.

Pure with unsafe

Now, let’s try a less cheap improvement: we’ll replace indexing operations with their unsafe equivalents. So V.! gets translated to V.unsafeIndex, and BS.index also gets replaced by BS.unsafeIndex.

import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BS
import qualified Data.Vector.Unboxed as V
import Data.List

levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = foldl' outer (V.generate (n + 1) id) [0 .. m - 1] V.! n
  where
    m = BS.length s1
    n = BS.length s2

    outer v0 i = V.constructN (n + 1) ctr
      where
        s1char = s1 `BS.unsafeIndex` i
        ctr v1 | V.length v1 == 0 = i + 1
        ctr v1 = min (substCost + substCostBase) $ 1 + min delCost insCost
          where
            j = V.length v1
            delCost = v0 `V.unsafeIndex` j
            insCost = v1 `V.unsafeIndex` (j - 1)
            substCostBase = v0 `V.unsafeIndex` (j - 1)
            substCost = if s1char == s2 `BS.unsafeIndex` (j - 1) then 0 else 1

Interestingly, in an ideal world of Dependent Haskell we wouldn’t have needed to do this: we could have proved statically that the index is always safe, and no runtime checks would be needed.
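In today’s Haskell, the closest we can get is to establish the bound by hand once, outside the loop, and then use the unchecked accessors inside it. A small sketch of that pattern (matchingPositions is a hypothetical example, not part of the edit distance code); note that nothing in the types enforces the invariant, so its correctness rests entirely on the programmer:

```haskell
import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BS (unsafeIndex)

-- | Counts the positions at which two strings agree, over their common length.
-- The bound 'len' is computed once, before the loop, so every index the loop
-- touches is in range and the unchecked BS.unsafeIndex cannot go out of bounds.
matchingPositions :: BS.ByteString -> BS.ByteString -> Int
matchingPositions a b = go 0 0
  where
    len = min (BS.length a) (BS.length b)
    go acc j
      | j >= len = acc
      | BS.unsafeIndex a j == BS.unsafeIndex b j = go (acc + 1) (j + 1)
      | otherwise = go acc (j + 1)
```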

Anyway, the run time gets to 4.2 seconds without the cheap improvements (compare that to 5.5 seconds of the “safe” version).

What about the ways to improve we considered earlier?

Adding {-# LANGUAGE Strict #-} gets us down to 2.5 seconds (compare that to the 3.4 seconds of the corresponding “safe” version).

Compiling via LLVM continues the trend and gets us to 1.6 seconds (vs 2.1 seconds we had before).

These results are interesting for a number of reasons:

  1. We’re already getting almost C++-like performance while still having fairly idiomatic code.
  2. We’re getting such performance despite allocating memory on each iteration (20000 times more than the C++ version does). Luckily, the corresponding vector isn’t long-lived, so it dies and gets collected out of the nursery — generational GCs rule!
  3. The “safety penalty” of both the strict and non-strict versions compiled with GHC’s NCG is on the order of 0.9 to 1.3 seconds, or roughly 25%. Personally, I’d have expected it to be considerably smaller, since the branch predictor should learn the right branch fairly quickly: those checks never fire.
  4. The LLVM backend reduces this penalty to about 0.5 seconds. Is it able to elide some of the checks, or do the checks become more branch predictor-friendly? Or is it the combination of both?

Mutable implementation

Let’s now start with a direct rewrite of the algorithm with a mutable Data.Array:

import qualified Data.Array.Base as A(unsafeRead, unsafeWrite)
import qualified Data.Array.ST as A
import qualified Data.ByteString.Char8 as BS
import Control.Monad
import Control.Monad.ST

levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = runST $ do
  v0Init <- A.newListArray (0, n) [0..]
  v1Init <- A.newArray_ (0, n)
  forM_ [0 .. m - 1] $ \i -> do
    let (v0, v1) | even i = (v0Init, v1Init)
                 | otherwise = (v1Init, v0Init)
    loop i v0 v1
  A.unsafeRead (if even m then v0Init else v1Init) n

  where
    m = BS.length s1
    n = BS.length s2

    loop :: Int -> A.STUArray s Int Int -> A.STUArray s Int Int -> ST s ()
    loop i v0 v1 = do
      A.unsafeWrite v1 0 (i + 1)
      let s1char = s1 `BS.index` i
      forM_ [0..n - 1] $ \j -> do
        delCost <- v0 `A.unsafeRead` (j + 1)
        insCost <- v1 `A.unsafeRead` j
        substCostBase <- v0 `A.unsafeRead` j
        let substCost = if s1char == s2 `BS.index` j then 0 else 1
        A.unsafeWrite v1 (j + 1) $ min (substCost + substCostBase) $ 1 + min delCost insCost

In my experience mutable arrays turn out to be faster than mutable vectors, so we’re not exploring the Data.Vector.Mutable family to keep this post’s length reasonable. For the very same reason we’re not considering the safe counterparts to the array indexing operations.

How fast does this turn out to be?

5.9 seconds. Dang. Way worse than the safe pure implementation! And even if we profile this code, well… At least I wasn’t able to derive anything reasonable from the profile.

On the other hand, this implementation barely triggers the GC at all: +RTS -sstderr reports just 2-3 minor and 2-3 major collections, as opposed to thousands of minor GCs with Data.Vector. It doesn’t help much in terms of performance, though (yes, generational GCs rule: live fast, die young).

What about our good old friends? Strictness everywhere reduces this to 4.1 seconds, and LLVM gains another 0.4 seconds, bringing us to 3.7 seconds.

No, thank you, I’ll keep the pure one.

But let’s keep going.

Explicit tail recursion

There are numerous hints that GHC really loves tail recursion, so let’s rewrite our code to avoid the forM_ combinator and use tail recursion explicitly:

import qualified Data.Array.Base as A(unsafeRead, unsafeWrite)
import qualified Data.Array.ST as A
import qualified Data.ByteString.Char8 as BS
import Control.Monad.ST

levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = runST $ do
  v0Init <- A.newListArray (0, n) [0..]
  v1Init <- A.newArray_ (0, n)
  loop 0 v0Init v1Init
  A.unsafeRead (if even m then v0Init else v1Init) n

  where
    m = BS.length s1
    n = BS.length s2

    loop :: Int -> A.STUArray s Int Int -> A.STUArray s Int Int -> ST s ()
    loop i v0 v1 | i == m = pure ()
                 | otherwise = do
      A.unsafeWrite v1 0 (i + 1)
      let s1char = s1 `BS.index` i
      let go j | j == n = pure ()
               | otherwise = do
            delCost <- v0 `A.unsafeRead` (j + 1)
            insCost <- v1 `A.unsafeRead` j
            substCostBase <- v0 `A.unsafeRead` j
            let substCost = if s1char == s2 `BS.index` j then 0 else 1
            A.unsafeWrite v1 (j + 1) $ min (substCost + substCostBase) $ 1 + min delCost insCost
            go (j + 1)
      go 0
      loop (i + 1) v1 v0

This gets us to 4.3 seconds. Making all of this strict further gets us down to 1.7 seconds, in the same ballpark as the unsafe vector version. Frankly, I’m surprised GHC wasn’t able to optimize the forM_ loop as nicely as the explicitly tail-recursive one.
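The rewrite from forM_ to explicit recursion is mechanical: the loop index becomes an argument of a local go that checks the bound, runs the body and calls itself with the next index. On a toy in-place prefix sum (prefixSumsForM and prefixSumsRec are hypothetical functions, not from this post) the two shapes look like this. Whether GHC compiles them to the same machine loop depends on the GHC version and flags; the measurements above suggest that for the edit distance code it did not.

```haskell
import Control.Monad (forM_)
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, getElems, newListArray, readArray, writeArray)

-- | In-place prefix sums, looping with the forM_ combinator.
prefixSumsForM :: [Int] -> [Int]
prefixSumsForM xs = runST $ do
  arr <- newListArray (0, n - 1) xs :: ST s (STUArray s Int Int)
  forM_ [1 .. n - 1] $ \j -> do
    prev <- readArray arr (j - 1)
    cur <- readArray arr j
    writeArray arr j (prev + cur)
  getElems arr
  where
    n = length xs

-- | The same loop written as explicit tail recursion.
prefixSumsRec :: [Int] -> [Int]
prefixSumsRec xs = runST $ do
  arr <- newListArray (0, n - 1) xs :: ST s (STUArray s Int Int)
  let go j
        | j >= n = pure ()
        | otherwise = do
            prev <- readArray arr (j - 1)
            cur <- readArray arr j
            writeArray arr j (prev + cur)
            go (j + 1)
  go 1
  getElems arr
  where
    n = length xs
```

Both functions turn [1, 2, 3, 4] into [1, 3, 6, 10]; only the loop structure differs.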

But hold your breath. Let’s build with LLVM now.

0.96 seconds. Wat? Faster than C++? Did I make a mistake somewhere?

0.96 seconds.

Wow.

But we can do better.

Big little things

We’re still doing safe string reads. Let’s fix that!

Replacing BS.index with BS.unsafeIndex chops off another 0.04 seconds, getting us to 0.92 seconds (by the way, from now on I’m only considering builds with strictness enabled and the LLVM backend).

There’s also another optimization we could try. Note that we write to v1[j+1] on the jth iteration and read that very same array cell on the next iteration. What if we pass the corresponding value straight to the tail-recursive call instead?

Here’s the code.
{-# LANGUAGE Strict #-}
{-# OPTIONS_GHC -fllvm #-}

import qualified Data.Array.Base as A(unsafeRead, unsafeWrite)
import qualified Data.Array.ST as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BS
import Control.Monad.ST

levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = runST $ do
  v0Init <- A.newListArray (0, n) [0..]
  v1Init <- A.newArray_ (0, n)
  loop 0 v0Init v1Init
  A.unsafeRead (if even m then v0Init else v1Init) n

  where
    m = BS.length s1
    n = BS.length s2

    loop :: Int -> A.STUArray s Int Int -> A.STUArray s Int Int -> ST s ()
    loop i v0 v1 | i == m = pure ()
                 | otherwise = do
      A.unsafeWrite v1 0 (i + 1)
      let s1char = s1 `BS.unsafeIndex` i
      let go j prev | j == n = pure ()
                    | otherwise = do
            delCost <- v0 `A.unsafeRead` (j + 1)
            substCostBase <- v0 `A.unsafeRead` j
            let substCost = if s1char == s2 `BS.unsafeIndex` j then 0 else 1
            let res = min (substCost + substCostBase) $ 1 + min delCost prev
            A.unsafeWrite v1 (j + 1) res
            go (j + 1) res
      go 0 (i + 1)
      loop (i + 1) v1 v0

This micro-optimization gives us roughly 0.09 seconds, reducing our run time to 0.83 seconds.

It’s also worth noting that the C++ version doesn’t benefit from the analogous optimization, so perhaps the compiler is already doing it for us (and, in fact, I’ve heard this optimization is indeed implemented in recent versions of gcc and clang).

Anyway, I think that’s good enough (and it’s actually not that ugly). Let’s summarize the results.

Summary

So here’s how it all looks together.

Sloppy benchmarks

First, the tabular form of the sloppy benchmark.

Here, the numbers are normalized so that 100% is the C++ version.

Version                         Base time   + strictness   + LLVM
Pure                            406%        250%           154%
Pure (unsafe indexing)          309%        181%           121%
Mutable array                   435%        298%           274%
Array + tail rec                318%        129%           70%
Array + tail rec + microopts    n/a         n/a            61%

Base time for the C++ version here is 1.36 seconds.

I’ve also had a chance to run a couple of tests on an i7-6700 machine, and the two most important data points are:

Looks like Haswell is closer to Haskell than Skylake, although the fastest Haskell version is still faster than C++ on the latter.

Criterion results

Here are the criterion graphs, one per machine:

[Figure: criterion results on Haswell (i7 4770)]
[Figure: criterion results on Skylake (i7 6700)]

Both machines have the same gcc version (9.2) and the same GHC (8.6). Again, Skylake is way more C++-friendly, but even there the best Haskell version wins.

Note that criterion shows the averages, not the minimum times (which might have been more suitable in this case).

More on the C++ version

After publishing the initial version of this post, I received some feedback on the C++ version of the code and some possible optimizations. Although this deviates slightly from the subject of optimizing Haskell code, let’s be good guys and go for full disclosure: let’s see how these changes affect gcc (version 9.2) and clang (version 9.0.1).

The suggestion is to replace std::min using the initializer list with a couple of nested std::mins, so that the contents of the inner loop

    auto delCost = v0[j + 1] + 1;
    auto insCost = v1[j] + 1;
    int substCost = s1char != s2[j];

    v1[j + 1] = std::min({ v0[j] + substCost, delCost, insCost });

become

    auto delCost = v0[j + 1] + 1;
    auto insCost = v1[j] + 1;
    int substCost = s1char != s2[j];

    v1[j + 1] = std::min(v0[j] + substCost, std::min(delCost, insCost));

This has no effect on gcc, but clang starts to shine: its run time drops from 1.5 seconds to 0.840 s, very close to our Haskell version!

This also opens up one more optimization opportunity: adding 1 just once, by changing the code to

    auto delCost = v0[j + 1];
    auto insCost = v1[j];
    int substCost = s1char != s2[j];

    v1[j + 1] = std::min(v0[j] + substCost, 1 + std::min(delCost, insCost));

This has a twofold effect on performance. On one hand, clang shines even more, getting a tad ahead of the Haskell version with 0.820 s of run time (although that’s somewhat within the measurement error). On the other hand, gcc slows down, going from 1.36 s to 1.45 s. Should we have two #ifdef branches if this were production code?

Anyway, firstly, I’m surprised that replacing the initializer-list-based std::min has such an effect on performance with clang, especially given that clang is perfectly able to optimize it away in simpler cases.

Secondly, one might shave a few milliseconds off the Haskell version too (replacing $ with $!, for instance, or explicitly unfolding the definition of min), but I’m not sure it’s worth it in the context of this benchmark.
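For concreteness, unfolding min for one cell of the row could look like the sketch below. cellCost is a hypothetical helper whose argument names mirror the inner loop of the final version; it computes exactly the same value as the min-based expression, and whether it is any faster would have to be measured.

```haskell
-- | One cell of the DP row, with both uses of min written out as explicit
-- comparisons. Equivalent to
--   min (substCost + substCostBase) (1 + min delCost prev)
cellCost :: Int -> Int -> Int -> Int -> Int
cellCost substCost substCostBase delCost prev =
  let viaSubst = substCost + substCostBase
      viaEdit = 1 + (if delCost < prev then delCost else prev)
  in if viaSubst < viaEdit then viaSubst else viaEdit
```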

Conclusion

Miscellaneous

What did I omit in this post?

All in all, take this post with a grain of salt.
