字符串匹配算法

我们经常需要确定一个较短的字符串(模式串)在一个较长字符串(主串)中是否出现或者出现的最小位置。例如 Python 中字符串类型的的 findindex 方法:

s = 'I Love Python'
print(s.find('Py'))  # 7
print(s.find('Pyc')) # -1

index 方法与 find 的唯一区别在于当主串中不存在模式串时会抛出 ValueError

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: substring not found

关于 Python(CPython) 中的字符串匹配算法后面会提到,先来看一下直观的算法和知名的 KMP 算法。

1. 顺序匹配算法

这是最直观的算法,就是将模式串沿着主串从左向右滑动,直到找到主串中与之相匹配的子字符串,并返回其位置。为了演示这一算法我们先构建一个虚拟的字符串类型:

class String:
  def __init__(self, ss="", length=0):
        self.string = ss
        self.length = length or len(ss)
    def __getitem__(self, i):
        return self.string[i]

这个字符串类型通过数组存储连续的字符,并保留了字符串的长度值。

def find(S, T, pos = 0):
  i = pos
  j = 0
  while i < S.length and j < T.length:
    if S[i] == T[j]:
      i += 1
      j += 1
    else:
      i = i - j + 1
      j = 0
  if j >= T.length:
    return i - T.length
  return -1

分析一下这个算法可以发现,在 find(String('I Love Python'), String('Py')) 的例子中,时间复杂度为 O(n+m),因为主串的下标是一直向前移动的。但是对于 find(String('PPPPPPPython'), String('Py')) 这样的情况,每当遇到模式串的第二个字符不匹配时,即 P != yij 都需要退回去重新向前移动,这就会导致复杂度变为 O(n*m)。KMP 算法就是针对这种情况的改进。

2. KMP 算法

这一改进算法由 Knuth、Pratt、Morris 同时发现,故得名 KMP。改进的原理也比较简单,我们希望可以尽量减少由于模式串匹配到一半发现不匹配时所导致的 ij 退回的步数,例如直到模式串的第 k 个字符才发现 S[k] != T[k],但是 k 之前已经比较过的字符串我们不希望白白浪费掉,然后回过头来重复比较一次,而且这写比较过的字符长度越长越好。例如模式串为:T = String('ABABC'),而主串为S = String('ABABDCCC'),这是如果 k = 4 ,即 T[4] != S[4],但此时 T[2:4] = 'AB' 恰好等于 T[0:2] = 'AB',也就是说不需要再从头比较一次 S[2:4] == T[0:2],因为 T = String('ABABC') 本身的性质已经决定了它们一定是相等的(否则也不会一直匹配到 k=4 才出现不匹配)。

总结来说就是,我们利用模式串中 T[0:k] == T[m-k:m] 的性质(如果存在的话),在字符串比较的时候可以省略一定的步数从而减少不必要的重复比较。虽然这需要我们付出额外的时间去检验模式串的这一性质,但由于模式串的长度往往小于主串,这样的付出还是值得的。然而如果模式串根本不具备这样的性质,例如完全是由不同字符组成的 String('ABCDE'),那么 KMP 算法反而增加了复杂度。

为了获取模式串的 KMP 性质,我们需要一个额外的数组来记录当第 j 个字符与主串不匹配时,我们可以跳过模式串的前 k 个字符,这个数组满足:

def nxt(T, j):
  if j == 0:
    return -1
  if T[0:k] == T[j-k:j]:
    return max(k)
  else:
    return 0

可以将长度为 m 的模式串 T 的每一个 nxt(T,j) 保存在一个数组中:

def kmp_next(T):
  nxt = [-1] * T.length
  i   = 0
  j   = -1
  while i < T.length:
    if j == -1 or T[j] == T[i]:
      i += 1
      j += 1
      if i < T.length:
        nxt[i] = j
    else:
      j = nxt[j]
  return nxt

再来完成 KMP 算法只需要对上面的顺序匹配法稍加改动即可:

def KMP(S, T, post = 0):
  nxt = kmp_next(T)
  i   = pos
  j   = 0
  while i < S.length and j < T.length:
    if j == -1 or S[i] == T[i]:
      i += 1
      j += 1
    else:
      j = nxt[j]
  if j >= T.length:
    return i - T.length
  return -1

3. Python 源码中的实现方式

为了探究一下 Python 中字符串匹配算法是什么样的,我去看了一下 GitHub 上的源码,位于 Objects/stringlib/fastsearch.h

根据头部注释的说明:

based on a mix between boyer-moore and horspool,

也就是混合了 B-MHorspool,另外注释中也提供了一篇详细说明的文章地址:The stringlib Library,这里暂时不做深入研究。

4. One More Think

上面都是关于在较长字符串中匹配寻找较短字符串的算法,还有另外一种问题是关于寻找任意两个字符串中的公共子序列,也就是常说的最长公共子序列(Longgest Common Subsequence, LCS)问题。其中这里的子序列是指所有与源字符串中出现顺序相同但不一定位置相同的子字符串,例如 PYTPYO 都是 PYTHON 的子序列。

这一问题的暴力解法复杂度相当可怕,因为每个长度为 m 的字符串共有 2^m 个子序列,因此一般采用动态规划(Dynamic Programming)的算法来解决。首先需要构造 LCS 问题的最优子结构:

设定字符串 X = [x1,x2,...,xm] 的第 i 个前缀为 Xi = [x1, x2,...,xi]X0 为空;假设两个字符串 X = [x1, x2,...,xm]Y = [y1, y2,...,yn] 的 LCS 为 Z = [z1, z2,...,zk],则可以将问题分解为:

  1. 如果 xm == yn,则 zk == xm == ynZk-1Xm-1Yn-1 的一个 LCS;
  2. 如果 xm != yn,且 zk != xm,则 ZXm-1Yn 的一个 LCS;
  3. 如果 xm != yn,且 zk != yn,则 ZXmYn-1 的一个 LCS。

由此可以找到 LCS 问题的重叠子问题中的递归解,设定二维数组 subs[i][j] 存储了 XiYj 的 LCS 的长度,则有:

  1. i == 0 or j == 0 时,subs[i][j] = 0
  2. i > 0 and j > 0 and xi == yj 时,subs[i][j] = subs[i-1][j-1] + 1
  3. i > 0 and j > 0 and xi != yj 时,subs[i][j] = max(subs[i][j-1], subs[i-1][j])

转化成代码:

def LCS_lengths(X, Y):
  subs = []

  # 初始化二维表
  for _ in range(X.length + 1): # 长度为 X.length + 1 是为了保存 X0
    subs.append([0] * (Y.length + 1))
  # 这里有一个坑,考虑一下为什么不可以用下面的方式进行初始化?
  # subs = [[0] * (Y.length + 1)] * (X.length + 1)
  for i in range(1, X.length + 1):
    for j in range(1, X.length + 1):
      if X[i-1] == Y[j-1]: # 字符串中下标是从 0 开始的,但这里的 i, j 是从 X1, Y1 开始的
        subs[i][j] = subs[i-1][j-1] + 1
      elif subs[i][j-1] >= subs[i-1][j]:
        subs[i][j] = subs[i][j-1]
      else:
        subs[i][j] = subs[i-1][j]
  return subs

检验一下:

X = String('ABCBDAB')
Y = String('BDCABA')
subs = LCS_lengths(X, Y)
print(subs)
"""
[[0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 1, 1, 1],
 [0, 1, 1, 1, 1, 2, 2],
 [0, 1, 1, 2, 2, 2, 2],
 [0, 1, 1, 2, 2, 3, 3],
 [0, 1, 2, 2, 2, 3, 3],
 [0, 1, 2, 2, 3, 3, 4],
 [0, 1, 2, 2, 3, 4, 4]]
"""

# 最后一个元素就是 Xm 与 Yn 的 LCS 长度:
print(subs[-1][-1])
# 4

上面的方法只帮助我们找到了 LCS 的长度,如果想要一个最常子序列的字符串呢?这时需要在生成 subs 的过程中记录每一次的比较,方便我们进行回溯:

def LCS_lengths(X, Y):
  subs = []
  road_map = []
  for _ in range(X.length + 1):
    subs.append([0] * (Y.length + 1))
    road_map.append([0] * (Y.length + 1))
  for i in range(1, X.length + 1):
    for j in range(1, Y.length + 1):
      if X[i-1] == Y[j-1]:
        subs[i][j] = subs[i-1][j-1] + 1
        road_map[i][j] = 'M' # Match
      elif subs[i][j-1] >= subs[i-1][j]:
        subs[i][j] = subs[i][j-1]
        road_map[i][j] = 'Y' # find from Yj-1
      else:
        subs[i][j] = subs[i-1][j]
        road_map[i][j] = 'X' # find from Xi-1
  return subs, road_map
def LCS_find(X, Y):
  _, road_map = LCS_lengths(X, Y)
  def _find(road, X, i, j, lcs):
    if i == 0 or j == 0:
      return
    if road[i][j] == 'M':
      _find(road, X, i-1, j-1, lcs)
      lcs.append(X[i])
    elif road[i][j] == 'Y':
      _find(road, X, i, j-1, lcs)
    else:
      _find(road, X, i-1, j, lcs)

    lcs = []
    _find(road_map, X, X.length, Y.length, lcs)
    return lcs
print(LCS_find(X, Y))
# ['B', 'C', 'B', 'A']
```%                                                                                       算法与数据结构  pbpaste | pbcopy
  算法与数据结构  pbpaste
# 字符串匹配算法

我们经常需要确定一个较短的字符串(模式串)在一个较长字符串(主串)中是否出现或者出现的最小位置。例如 Python 中字符串类型的的 `find`  `index` 方法:

```python
s = 'I Love Python'
print(s.find('Py'))  # 7
print(s.find('Pyc')) # -1

index 方法与 find 的唯一区别在于当主串中不存在模式串时会抛出 ValueError

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: substring not found

关于 Python(CPython) 中的字符串匹配算法后面会提到,先来看一下直观的算法和知名的 KMP 算法。

1. 顺序匹配算法

这是最直观的算法,就是将模式串沿着主串从左向右滑动,直到找到主串中与之相匹配的子字符串,并返回其位置。为了演示这一算法我们先构建一个虚拟的字符串类型:

class String:
  def __init__(self, ss="", length=0):
        self.string = ss
        self.length = length or len(ss)
    def __getitem__(self, i):
        return self.string[i]

这个字符串类型通过数组存储连续的字符,并保留了字符串的长度值。

def find(S, T, pos = 0):
  i = pos
  j = 0
  while i < S.length and j < T.length:
    if S[i] == T[j]:
      i += 1
      j += 1
    else:
      i = i - j + 1
      j = 0
  if j >= T.length:
    return i - T.length
  return -1

分析一下这个算法可以发现,在 find(String('I Love Python'), String('Py')) 的例子中,时间复杂度为 O(n+m),因为主串的下标是一直向前移动的。但是对于 find(String('PPPPPPPython'), String('Py')) 这样的情况,每当遇到模式串的第二个字符不匹配时,即 P != yij 都需要退回去重新向前移动,这就会导致复杂度变为 O(n*m)。KMP 算法就是针对这种情况的改进。

2. KMP 算法

这一改进算法由 Knuth、Pratt、Morris 同时发现,故得名 KMP。改进的原理也比较简单,我们希望可以尽量减少由于模式串匹配到一半发现不匹配时所导致的 ij 退回的步数,例如直到模式串的第 k 个字符才发现 S[k] != T[k],但是 k 之前已经比较过的字符串我们不希望白白浪费掉,然后回过头来重复比较一次,而且这写比较过的字符长度越长越好。例如模式串为:T = String('ABABC'),而主串为S = String('ABABDCCC'),这是如果 k = 4 ,即 T[4] != S[4],但此时 T[2:4] = 'AB' 恰好等于 T[0:2] = 'AB',也就是说不需要再从头比较一次 S[2:4] == T[0:2],因为 T = String('ABABC') 本身的性质已经决定了它们一定是相等的(否则也不会一直匹配到 k=4 才出现不匹配)。

总结来说就是,我们利用模式串中 T[0:k] == T[m-k:m] 的性质(如果存在的话),在字符串比较的时候可以省略一定的步数从而减少不必要的重复比较。虽然这需要我们付出额外的时间去检验模式串的这一性质,但由于模式串的长度往往小于主串,这样的付出还是值得的。然而如果模式串根本不具备这样的性质,例如完全是由不同字符组成的 String('ABCDE'),那么 KMP 算法反而增加了复杂度。

为了获取模式串的 KMP 性质,我们需要一个额外的数组来记录当第 j 个字符与主串不匹配时,我们可以跳过模式串的前 k 个字符,这个数组满足:

def nxt(T, j):
  if j == 0:
    return -1
  if T[0:k] == T[j-k:j]:
    return max(k)
  else:
    return 0

可以将长度为 m 的模式串 T 的每一个 nxt(T,j) 保存在一个数组中:

def kmp_next(T):
  nxt = [-1] * T.length
  i   = 0
  j   = -1
  while i < T.length:
    if j == -1 or T[j] == T[i]:
      i += 1
      j += 1
      if i < T.length:
        nxt[i] = j
    else:
      j = nxt[j]
  return nxt

再来完成 KMP 算法只需要对上面的顺序匹配法稍加改动即可:

def KMP(S, T, post = 0):
  nxt = kmp_next(T)
  i   = pos
  j   = 0
  while i < S.length and j < T.length:
    if j == -1 or S[i] == T[i]:
      i += 1
      j += 1
    else:
      j = nxt[j]
  if j >= T.length:
    return i - T.length
  return -1

3. Python 源码中的实现方式

为了探究一下 Python 中字符串匹配算法是什么样的,我去看了一下 GitHub 上的源码,位于 Objects/stringlib/fastsearch.h

根据头部注释的说明:

based on a mix between boyer-moore and horspool,

也就是混合了 B-MHorspool,另外注释中也提供了一篇详细说明的文章地址:The stringlib Library,这里暂时不做深入研究。

4. One More Think

上面都是关于在较长字符串中匹配寻找较短字符串的算法,还有另外一种问题是关于寻找任意两个字符串中的公共子序列,也就是常说的最长公共子序列(Longgest Common Subsequence, LCS)问题。其中这里的子序列是指所有与源字符串中出现顺序相同但不一定位置相同的子字符串,例如 PYTPYO 都是 PYTHON 的子序列。

这一问题的暴力解法复杂度相当可怕,因为每个长度为 m 的字符串共有 2^m 个子序列,因此一般采用动态规划(Dynamic Programming)的算法来解决。首先需要构造 LCS 问题的最优子结构:

设定字符串 X = [x1,x2,...,xm] 的第 i 个前缀为 Xi = [x1, x2,...,xi]X0 为空;假设两个字符串 X = [x1, x2,...,xm]Y = [y1, y2,...,yn] 的 LCS 为 Z = [z1, z2,...,zk],则可以将问题分解为:

  1. 如果 xm == yn,则 zk == xm == ynZk-1Xm-1Yn-1 的一个 LCS;
  2. 如果 xm != yn,且 zk != xm,则 ZXm-1Yn 的一个 LCS;
  3. 如果 xm != yn,且 zk != yn,则 ZXmYn-1 的一个 LCS。

由此可以找到 LCS 问题的重叠子问题中的递归解,设定二维数组 subs[i][j] 存储了 XiYj 的 LCS 的长度,则有:

  1. i == 0 or j == 0 时,subs[i][j] = 0
  2. i > 0 and j > 0 and xi == yj 时,subs[i][j] = subs[i-1][j-1] + 1
  3. i > 0 and j > 0 and xi != yj 时,subs[i][j] = max(subs[i][j-1], subs[i-1][j])

转化成代码:

def LCS_lengths(X, Y):
  subs = []

  # 初始化二维表
  for _ in range(X.length + 1): # 长度为 X.length + 1 是为了保存 X0
    subs.append([0] * (Y.length + 1))
  # 这里有一个坑,考虑一下为什么不可以用下面的方式进行初始化?
  # subs = [[0] * (Y.length + 1)] * (X.length + 1)
  for i in range(1, X.length + 1):
    for j in range(1, X.length + 1):
      if X[i-1] == Y[j-1]: # 字符串中下标是从 0 开始的,但这里的 i, j 是从 X1, Y1 开始的
        subs[i][j] = subs[i-1][j-1] + 1
      elif subs[i][j-1] >= subs[i-1][j]:
        subs[i][j] = subs[i][j-1]
      else:
        subs[i][j] = subs[i-1][j]
  return subs

检验一下:

X = String('ABCBDAB')
Y = String('BDCABA')
subs = LCS_lengths(X, Y)
print(subs)
"""
[[0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 1, 1, 1],
 [0, 1, 1, 1, 1, 2, 2],
 [0, 1, 1, 2, 2, 2, 2],
 [0, 1, 1, 2, 2, 3, 3],
 [0, 1, 2, 2, 2, 3, 3],
 [0, 1, 2, 2, 3, 3, 4],
 [0, 1, 2, 2, 3, 4, 4]]
"""

# 最后一个元素就是 Xm 与 Yn 的 LCS 长度:
print(subs[-1][-1])
# 4

上面的方法只帮助我们找到了 LCS 的长度,如果想要一个最常子序列的字符串呢?这时需要在生成 subs 的过程中记录每一次的比较,方便我们进行回溯:

def LCS_lengths(X, Y):
  subs = []
  road_map = []
  for _ in range(X.length + 1):
    subs.append([0] * (Y.length + 1))
    road_map.append([0] * (Y.length + 1))
  for i in range(1, X.length + 1):
    for j in range(1, Y.length + 1):
      if X[i-1] == Y[j-1]:
        subs[i][j] = subs[i-1][j-1] + 1
        road_map[i][j] = 'M' # Match
      elif subs[i][j-1] >= subs[i-1][j]:
        subs[i][j] = subs[i][j-1]
        road_map[i][j] = 'Y' # find from Yj-1
      else:
        subs[i][j] = subs[i-1][j]
        road_map[i][j] = 'X' # find from Xi-1
  return subs, road_map
def LCS_find(X, Y):
  _, road_map = LCS_lengths(X, Y)
  def _find(road, X, i, j, lcs):
    if i == 0 or j == 0:
      return
    if road[i][j] == 'M':
      _find(road, X, i-1, j-1, lcs)
      lcs.append(X[i])
    elif road[i][j] == 'Y':
      _find(road, X, i, j-1, lcs)
    else:
      _find(road, X, i-1, j, lcs)

    lcs = []
    _find(road_map, X, X.length, Y.length, lcs)
    return lcs
print(LCS_find(X, Y))
# ['B', 'C', 'B', 'A']