Writing a fast edit distance implementation
In this post we will implement the Levenshtein edit distance algorithm, aiming at reasonably performant code. We will start with a more or less idiomatic pure implementation and see how various changes (such as strictness annotations or compilation options) affect performance.
We will also compare this to a baseline C++ implementation.
Spoiler alerts:
- The C++ implementation turns out to be slower than the fastest Haskell implementation.
- LLVM backend really shines here.
Long story short, the Levenshtein algorithm takes two strings and computes how many characters have to be inserted, deleted or changed in order to make those strings match.
We will follow what’s known as the dynamic programming approach with two matrix rows, having O(mn) time complexity and O(min(m, n)) space complexity, where m and n are the lengths of the respective input strings.
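For reference, here is the standard recurrence underlying all the implementations below (this is the textbook formulation, nothing specific to this post): writing D(i, j) for the distance between the first i characters of s1 and the first j characters of s2, we have

$$
D_{i,0} = i, \qquad D_{0,j} = j, \qquad
D_{i,j} = \min\big(D_{i-1,j} + 1,\ D_{i,j-1} + 1,\ D_{i-1,j-1} + [s_1[i] \ne s_2[j]]\big),
$$

where the bracket is 1 if the characters differ and 0 otherwise. Row i depends only on row i-1, which is exactly what allows keeping just two rows around.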
Motivation
There are already a few implementations of the edit distance, namely the one in the text-metrics package as well as the ultra-fast edit-distance.
There were some reasons for me to write yet another implementation besides the obvious one:
- text-metrics works with Data.Text, so there must be some performance penalty related to dealing with Unicode. In short, I didn’t care about non-ASCII text, so I didn’t need that penalty.
- edit-distance uses a square STUArray (a full-matrix approach), thus having a different space complexity. I’m sort of ashamed to admit that the previous sentence is actually incorrect, and it’s just me reading the code wrongly — in fact, the underlying algorithm there is way, way more complex.
- And, most importantly, writing code is fun, trying to make code fast is twice as fun, and learning from that experience is infinitely fun!
Anyway, here we go.
Benchmarking
We will have two ways of benchmarking this. The first one is with the criterion package that we all know and love. But running criterion is slow, so we’ll also have a really stupid benchmark that generates three strings:
- s1 of 20000 characters a,
- s2 of 20000 characters a,
- s3 of 20000 characters b,

and then computes the distance from s1 to s2 (which is obviously 0) and the distance from s1 to s3 (which is 20000).
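Here’s a minimal sketch of what that driver might look like (the function name matches the implementations developed below):

```haskell
import qualified Data.ByteString.Char8 as BS

-- levenshteinDistance is any of the implementations from this post
main :: IO ()
main = do
  let s1 = BS.replicate 20000 'a'
      s2 = BS.replicate 20000 'a'
      s3 = BS.replicate 20000 'b'
  print $ levenshteinDistance s1 s2  -- expected: 0
  print $ levenshteinDistance s1 s3  -- expected: 20000
```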
Since the time to prepare the benchmark data is dwarfed by the run time, we will just measure how long the whole program runs (via +RTS -sstderr) and take the minimum of three measurements: the OS and the rest of the processes on my machine can make the program run slower, but they can never make it run faster than under ideal circumstances. Moreover, since this test is fully deterministic, there is no need to average over different inputs.
So we’ll go along citing the results of this stupid benchmark, and then at the end we’ll present the overall graphs of all the implementations as benchmarked via criterion.
Baseline
First, let’s set our expectations: how fast can we possibly be?
There are two ways to estimate this: theoretical and practical.
Theoretically
Let’s assume that all of our data fits in L2 but gets evicted from L1. Is this justified? Let’s look at what data we have to keep around:
- Two strings of 20000 bytes each: roughly 40 KB (already more than L1 on my machine!).
- Two arrays of 20000 8-byte integers each, which gives us two times 160 KB (and even if we used 4-byte integers, it’d be two times 80 KB, which would still get evicted from L1).
How much data do we have to pass to/from L2?
We do roughly 20000 iterations, each one reading the whole string (20 KB) and the whole array built on the previous step (160 KB), while writing back the new version of the array (also 160 KB). Focusing just on the reads, we need to read roughly (20 + 160) × 20000 KB from L2 in total, which is 3.6 GB. The Haswell CPU that I’m using has an established L2 read bandwidth of around 50 GB/s, which gives something on the order of 0.1 s as an absolute minimum run time per single comparison.
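Spelling that estimate out:

$$
(20 + 160)\ \text{KB} \times 20000 \approx 3.6\ \text{GB}, \qquad \frac{3.6\ \text{GB}}{50\ \text{GB/s}} \approx 0.07\ \text{s},
$$

which we round up to roughly 0.1 s.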
Since we’re doing two comparisons, the absolute minimum would be about 0.2 s.
This makes a ton of assumptions, like infinite CPU performance (so we’re bottlenecked on memory) or lack of interference between simultaneous reads and writes.

Practically
Let’s throw together this little C++ program as a practical baseline:
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>
#include <string>
size_t lev_dist(const std::string& s1, const std::string& s2)
{
    const auto m = s1.size();
    const auto n = s2.size();

    std::vector<int64_t> v0;
    v0.resize(n + 1);
    std::iota(v0.begin(), v0.end(), 0);

    auto v1 = v0;

    for (size_t i = 0; i < m; ++i)
    {
        v1[0] = i + 1;

        const auto s1char = s1[i];
        for (size_t j = 0; j < n; ++j)
        {
            auto delCost = v0[j + 1] + 1;
            auto insCost = v1[j] + 1;
            int substCost = s1char != s2[j];
            v1[j + 1] = std::min({ v0[j] + substCost, delCost, insCost });
        }

        std::swap(v0, v1);
    }

    return v0[n];
}

int main()
{
    std::string s1(20000, 'a');
    std::string s2(20000, 'a');
    std::string s3(20000, 'b');

    std::cout << lev_dist(s1, s2) << std::endl;
    std::cout << lev_dist(s1, s3) << std::endl;
}
Its run time is around 1.3 seconds.
More precisely, the minimum run time across three runs is 1.356 seconds (Core i7 4770, no overclocking).
For reproducibility: the C++ compiler used is gcc 9.2, compiling with -O3 -march=native.
A careful reader might take issue with the initializer-list-based std::min, but we’ll get to that (and some other C++-related things) later.
So we’re going to use 1.3 seconds as our goal.
The code
Pure implementation
No explicit mutability besides what’s inside Data.Vector (though it’s arguable whether that should count). No explicit tail recursion. Just a strict left fold and Data.Vector.constructN:
import qualified Data.ByteString as BS
import qualified Data.Vector.Unboxed as V
import Data.List
levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = foldl' outer (V.generate (n + 1) id) [0 .. m - 1] V.! n
  where
    m = BS.length s1
    n = BS.length s2

    outer v0 i = V.constructN (n + 1) ctr
      where
        s1char = s1 `BS.index` i

        ctr v1 | V.length v1 == 0 = i + 1
        ctr v1 = min (substCost + substCostBase) $ 1 + min delCost insCost
          where
            j = V.length v1
            delCost = v0 V.! j
            insCost = v1 V.! (j - 1)
            substCostBase = v0 V.! (j - 1)
            substCost = if s1char == s2 `BS.index` (j - 1) then 0 else 1
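A note on constructN, since it’s doing the heavy lifting here: it builds a vector of the given length by repeatedly applying the generator to the prefix constructed so far, which is why ctr can use V.length v1 as the index currently being filled. A tiny standalone illustration:

```haskell
import qualified Data.Vector.Unboxed as V

-- The generator receives the already-built prefix; its length is the next index,
-- so this produces fromList [0,1,2,3].
indices :: V.Vector Int
indices = V.constructN 4 V.length
```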
That might seem a bit ugly, but that’s still reasonably nice and idiomatic (as much as such algorithms can be expressed idiomatically at all).
The run time with the sloppy benchmark is about 5.5 seconds. Ugh. Not nice.
Can we do better?
Cheap improvements
There are two ways to improve the run time without even touching the code of the function.
The first one is to add strictness annotations, since the compiler is unlikely to have figured them all out on its own.
We’re going the lazy path and just adding {-# LANGUAGE Strict #-}
to the top of the file.
The result? 3.4 seconds. Better, but not that much.
But what if we compile our code via LLVM in addition to the Strict
pragma?
Let’s add {-# OPTIONS_GHC -fllvm #-}
to the top of the file!
The result? 2.1 seconds! That’s a huge performance improvement for a code generator change, and starts getting close to the C++ version.
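So the top of the file now carries both of these:

```haskell
{-# LANGUAGE Strict #-}
{-# OPTIONS_GHC -fllvm #-}
```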
Pure with unsafe
Now, let’s try a less cheap improvement: we’ll replace the indexing operations with their unsafe equivalents. So V.! gets translated to V.unsafeIndex, and BS.index gets replaced by BS.unsafeIndex.
Full code
import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BS
import qualified Data.Vector.Unboxed as V
import Data.List
levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = foldl' outer (V.generate (n + 1) id) [0 .. m - 1] V.! n
  where
    m = BS.length s1
    n = BS.length s2

    outer v0 i = V.constructN (n + 1) ctr
      where
        s1char = s1 `BS.unsafeIndex` i

        ctr v1 | V.length v1 == 0 = i + 1
        ctr v1 = min (substCost + substCostBase) $ 1 + min delCost insCost
          where
            j = V.length v1
            delCost = v0 `V.unsafeIndex` j
            insCost = v1 `V.unsafeIndex` (j - 1)
            substCostBase = v0 `V.unsafeIndex` (j - 1)
            substCost = if s1char == s2 `BS.unsafeIndex` (j - 1) then 0 else 1
Interestingly, in an ideal world of Dependent Haskell we wouldn’t have needed to do this: we could have proved statically that the index is always safe, and no runtime checks would be needed.
Anyway, the run time gets to 4.2 seconds without the cheap improvements (compare that to 5.5 seconds of the “safe” version).
What about the ways to improve we considered earlier?
Adding {-# LANGUAGE Strict #-}
gets us down to 2.5 seconds (compare that to the 3.4 seconds of the corresponding “safe” version).
Compiling via LLVM continues the trend and gets us to 1.6 seconds (vs 2.1 seconds we had before).
These results are interesting for a number of reasons:
- We’re already getting almost C++-like performance while still having quite idiomatic code.
- We’re getting such performance despite allocating memory on each iteration (20000 times more than the C++ version does). Luckily, the corresponding vector isn’t long-lived, so it dies and gets collected out of the nursery — generational GCs rule!
- The “safety penalty” of both the strict and non-strict versions compiled with GHC’s NCG seems to be on the order of 0.9-1.3 seconds, or something like 20%. Personally, I’d expect it to be considerably smaller: the branch predictor should learn the right branch fairly quickly, since those checks never fire.
- The LLVM backend reduces this penalty to about 0.5 seconds. Is it able to elide some of the checks, or do the checks become more branch predictor-friendly? Or is it the combination of both?
Mutable implementation
Let’s now start with a direct rewrite of the algorithm with a mutable Data.Array
:
import qualified Data.Array.Base as A(unsafeRead, unsafeWrite)
import qualified Data.Array.ST as A
import qualified Data.ByteString.Char8 as BS
import Control.Monad
import Control.Monad.ST
levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = runST $ do
    v0Init <- A.newListArray (0, n) [0..]
    v1Init <- A.newArray_ (0, n)
    forM_ [0 .. m - 1] $ \i -> do
      let (v0, v1) | even i = (v0Init, v1Init)
                   | otherwise = (v1Init, v0Init)
      loop i v0 v1
    A.unsafeRead (if even m then v0Init else v1Init) n
  where
    m = BS.length s1
    n = BS.length s2

    loop :: Int -> A.STUArray s Int Int -> A.STUArray s Int Int -> ST s ()
    loop i v0 v1 = do
      A.unsafeWrite v1 0 (i + 1)
      let s1char = s1 `BS.index` i
      forM_ [0 .. n - 1] $ \j -> do
        delCost <- v0 `A.unsafeRead` (j + 1)
        insCost <- v1 `A.unsafeRead` j
        substCostBase <- v0 `A.unsafeRead` j
        let substCost = if s1char == s2 `BS.index` j then 0 else 1
        A.unsafeWrite v1 (j + 1) $ min (substCost + substCostBase) $ 1 + min delCost insCost
In my experience mutable arrays turn out to be faster than mutable vectors,
so we’re not exploring the Data.Vector.Mutable
family to keep this post’s length reasonable.
For the very same reason we’re not considering the safe counterparts to the array indexing operations.
How fast does this turn out to be?
5.9 seconds. Dang. Way worse than the safe pure implementation! And even if we profile this code, well… At least I wasn’t able to derive anything reasonable from the profile.
On the other hand, at least this implementation causes barely any GCs: it’s 2-3 minor ones and 2-3 major ones according to +RTS -sstderr, as opposed to thousands of minor GCs with the Data.Vector-based versions. It doesn’t help much in terms of performance, though (yes, generational GCs rule: live fast, die young).
What about our good old friends? Strictness everywhere reduces this to 4.1 seconds, and LLVM shaves off another 0.4 seconds, bringing us to 3.7 seconds.
No, thank you, I’ll keep the pure one.
But let’s keep going.
Explicit tail recursion
There are numerous hints that GHC really loves tail recursion,
so let’s rewrite our code to avoid the forM_
combinator and use tail recursion explicitly:
import qualified Data.Array.Base as A(unsafeRead, unsafeWrite)
import qualified Data.Array.ST as A
import qualified Data.ByteString.Char8 as BS
import Control.Monad.ST
levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = runST $ do
    v0Init <- A.newListArray (0, n) [0..]
    v1Init <- A.newArray_ (0, n)
    loop 0 v0Init v1Init
    A.unsafeRead (if even m then v0Init else v1Init) n
  where
    m = BS.length s1
    n = BS.length s2

    loop :: Int -> A.STUArray s Int Int -> A.STUArray s Int Int -> ST s ()
    loop i v0 v1 | i == m = pure ()
                 | otherwise = do
        A.unsafeWrite v1 0 (i + 1)
        let s1char = s1 `BS.index` i
        let go j | j == n = pure ()
                 | otherwise = do
              delCost <- v0 `A.unsafeRead` (j + 1)
              insCost <- v1 `A.unsafeRead` j
              substCostBase <- v0 `A.unsafeRead` j
              let substCost = if s1char == s2 `BS.index` j then 0 else 1
              A.unsafeWrite v1 (j + 1) $ min (substCost + substCostBase) $ 1 + min delCost insCost
              go (j + 1)
        go 0
        loop (i + 1) v1 v0
This gets us to 4.3 seconds. Making all this strict further gets us down to 1.7 seconds — just as with unsafe vectors.
Frankly I’m surprised GHC wasn’t able to optimize forM_
as nicely as it does with tail recursion.
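To be clear about what the rewrite buys us: the explicit go loop is essentially what one would hope GHC derives from forM_ over a range on its own. A sketch of the intended equivalence (forRange_ is a hypothetical helper name, not a library function):

```haskell
-- Counting loop written with explicit tail recursion, as in the code above;
-- forM_ [lo .. hi - 1] body should behave the same, but goes through a list.
forRange_ :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
forRange_ lo hi body = go lo
  where
    go j | j == hi = pure ()
         | otherwise = body j >> go (j + 1)
```

The hand-written version never materializes the list, which is apparently what GHC fails to eliminate here.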
But hold your breath. Let’s build with LLVM now.
0.96 seconds. Wat? Faster than C++? Did I make a mistake somewhere?
0.96 seconds.
Wow.
But we can do better.
Big little things
We’re still doing safe string reads. Let’s fix that!
Replacing BS.index with BS.unsafeIndex chops off another 0.04 seconds, getting us to 0.92 seconds (BTW, I’m only considering the versions with strictness enabled and the LLVM backend from now on).
There’s also another optimization we could try.
Note that we write to v1[j+1] on the jth iteration and read from the very same array cell on the next one. What if we pass the corresponding value straight to the tail-recursive call instead?
Here’s the code.
{-# LANGUAGE Strict #-}
{-# OPTIONS_GHC -fllvm #-}
import qualified Data.Array.Base as A(unsafeRead, unsafeWrite)
import qualified Data.Array.ST as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BS
import Control.Monad.ST
levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = runST $ do
    v0Init <- A.newListArray (0, n) [0..]
    v1Init <- A.newArray_ (0, n)
    loop 0 v0Init v1Init
    A.unsafeRead (if even m then v0Init else v1Init) n
  where
    m = BS.length s1
    n = BS.length s2

    loop :: Int -> A.STUArray s Int Int -> A.STUArray s Int Int -> ST s ()
    loop i v0 v1 | i == m = pure ()
                 | otherwise = do
        A.unsafeWrite v1 0 (i + 1)
        let s1char = s1 `BS.unsafeIndex` i
        let go j prev | j == n = pure ()
                      | otherwise = do
              delCost <- v0 `A.unsafeRead` (j + 1)
              substCostBase <- v0 `A.unsafeRead` j
              let substCost = if s1char == s2 `BS.unsafeIndex` j then 0 else 1
              let res = min (substCost + substCostBase) $ 1 + min delCost prev
              A.unsafeWrite v1 (j + 1) res
              go (j + 1) res
        go 0 (i + 1)
        loop (i + 1) v1 v0
This micro-optimization gives us roughly 0.09 seconds, reducing our run time to 0.83 seconds.
It’s also worth noting that the C++ version doesn’t benefit from the analogous optimization, so perhaps the compiler is already doing it for us (and, in fact, I’ve heard this optimization is indeed implemented in recent versions of gcc and clang).
Anyway, I think that’s good enough (and it’s actually not that ugly). Let’s summarize the results.
Summary
So here’s how it all looks together.
Sloppy benchmarks
First, the tabular form of the sloppy benchmark.
Here, the numbers are normalized so that 100% is the C++ version.
| Version | Base time | + strictness | + LLVM |
| --- | --- | --- | --- |
| Pure | 406% | 250% | 154% |
| Pure (unsafe indexing) | 309% | 181% | 121% |
| Mutable array | 435% | 298% | 274% |
| Array + tail rec | 318% | 129% | 70% |
| Array + tail rec + microopts | — | — | 61% |
Base time for the C++ version here is 1.36 seconds.
I’ve also had a chance to run a couple of tests on an i7-6700 machine, and the two most important data points are:
- 178% for the pure version with unsafe indexing, strictness and LLVM.
- 84% for the best performing version.
Looks like Haswell is closer to Haskell than Skylake, although the fastest Haskell version is still faster than C++ on the latter.
Criterion results
Here are the criterion graphs, one for Haswell (i7 4770) and one for Skylake (i7 6700).

[Interactive criterion graphs omitted.]

Both machines have the same gcc version (namely, 9.2), and they both use the same GHC (8.6). Again, Skylake is way more C++-friendly, but even there the best Haskell version wins.
Note that criterion shows averages, not minimum times (which might have been more suitable in this case).
More on the C++ version
After publishing the initial version of this post, I received some feedback on the C++ version of the code and some possible optimizations. Although this deviates slightly from the subject of optimizing Haskell code, let’s be good guys and go for full disclosure. So, let’s see how these changes affect gcc (version 9.2) and clang (version 9.0.1).
The suggestion is to replace the initializer-list-based std::min with a couple of nested std::mins, so that the contents of the inner loop
auto delCost = v0[j + 1] + 1;
auto insCost = v1[j] + 1;
int substCost = s1char != s2[j];
v1[j + 1] = std::min({ v0[j] + substCost, delCost, insCost });
become
auto delCost = v0[j + 1] + 1;
auto insCost = v1[j] + 1;
int substCost = s1char != s2[j];
v1[j + 1] = std::min(v0[j] + substCost, std::min(delCost, insCost));
This has no effect on gcc, but clang starts to shine: its run time goes from 1.5 seconds down to 0.840 s, very close to our Haskell version!
This also enables one more optimization: adding 1 just once, by changing the code to
auto delCost = v0[j + 1];
auto insCost = v1[j];
int substCost = s1char != s2[j];
v1[j + 1] = std::min(v0[j] + substCost, 1 + std::min(delCost, insCost));
This has a twofold effect on performance. On one hand, clang shines even more, getting a tad ahead of the Haskell version with 0.820 s of run time (although that’s within measurement error). On the other hand, gcc slows down, going from 1.36 s to 1.45 s. Should we keep two #ifdef branches if this were production code?
Anyway, firstly, I’m surprised that replacing the initializer-list-based std::min has such an effect on performance with clang, especially given that clang is perfectly able to optimize it away in simpler cases. Secondly, one might shave a few milliseconds off the Haskell version too (replacing $ with $!, for instance, or explicitly unfolding the definition of min), but I’m not sure it’s worth it in the context of this benchmark.
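For concreteness, the first of those tweaks would be a one-character change to the inner write (shown here as a fragment of the tail-recursive loop above, not a standalone program):

```haskell
-- before (as in the tail-recursive version above):
A.unsafeWrite v1 (j + 1) $ min (substCost + substCostBase) $ 1 + min delCost insCost
-- after: $! forces the result to WHNF before the write, so no thunk is stored;
-- under {-# LANGUAGE Strict #-} this is likely a no-op, hence the tiny expected gain
A.unsafeWrite v1 (j + 1) $! min (substCost + substCostBase) (1 + min delCost insCost)
```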
Conclusion
- It’s certainly possible to get within a reasonable margin of C++ performance with more or less idiomatic, pure code.
- Strictness (or careful consideration thereof) matters.
- GHC isn’t that great at seeing that certain things can be expressed in a tail-recursive manner. It’s definitely worth spending some time rewriting computationally heavy algorithms in this style.
- LLVM matters. Gains this big aren’t something I’ve seen often, even for rather heavy number-crunching applications.
- Code optimization is a strictly non-linear process. One cannot expect the same gain from the same changes (even “technical” ones like strictness annotations or a code generator switch) for different algorithms. In particular, we could have decided against going the Data.Array route at all after seeing how badly the base mutable version performs compared to the pure one, but it turns out to be the actual winner after some minor tweaks.
Miscellaneous
What did I omit in this post?
- Firstly, on the meta-reasoning level, this can be considered a success story post. Of course, it does not imply that Haskell can always be optimized to match C++ performance. It just means that I was lucky enough to stumble upon a problem where it performs well. If it didn’t perform so well, this post wouldn’t have been this enthusiastic (or wouldn’t have existed at all).
- Importing the right ByteString matters. Data.ByteString.Char8 is marginally but statistically significantly slower than Data.ByteString.
- GHC version matters. Exploring this is a whole different dimension, though. The code above has been compiled with GHC 8.6 (Stackage LTS 14.16). GHC 8.8 (with some nightly Stackage snapshot) seems to produce better code with NCG, but it gets worse with LLVM. It doesn’t affect the best-performing variant above at all, though.
- Dependent types matter. I didn’t omit this in the post, but I can’t help but emphasize it again: having to write unsafe makes me cringe a little, especially when a sufficiently expressive language enables me to prove that it’s safe.
- Another dimension worth exploring is the behaviour on a wider set of inputs. Comparing two equal (or two completely different) strings doesn’t sound like a very representative benchmark, but doing this right is another story.
- Microarchitecture matters. I shall mention that I have received a report of the final best-performing Haskell version being “just” on par with the C++ version, even when the latter is compiled with gcc. I wasn’t able to reproduce that on any of the hardware available to me, and that report came from a Ryzen machine, so I blame AMD. And, of course, being on par is obviously still great.
And, of course, this is obviously still great. - Our baseline C++ version can surely be optimized further, but that probably requires way more time that I have spent on optimizing Haskell.
- Input strings matter. I did a few test runs on random strings (of the same size as the strings considered in this post, of course), and both the C++ and the Haskell versions get slower (the Haskell one a bit more so). But measuring anything involving randomness is nontrivial, even more so when comparing across different languages with different random number generators.
- …and the rest of the things I omitted in this post.
All in all, take this post with a grain of salt.
Edit history
- Jan 20: added a section about C++ code variations and their effects on performance.