bk-tree  0.1.4
Header-only Burkhard-Keller tree library
bktree.hpp
1 //
2 // bk-tree Header-only Burkhard-Keller tree library
3 // Copyright (C) 2020-2023 John Law
4 //
5 // This file is part of bk-tree.
6 //
7 // bk-tree is free software: you can redistribute it and/or modify
8 // it under the terms of the GNU General Public License as published by
9 // the Free Software Foundation, either version 3 of the License, or
10 // (at your option) any later version.
11 //
12 // bk-tree is distributed in the hope that it will be useful,
13 // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 // GNU General Public License for more details.
16 //
17 // You should have received a copy of the GNU General Public License
18 // along with bk-tree. If not, see <https://www.gnu.org/licenses/>.
19 //
20 
21 #pragma once
22 
23 #ifndef BK_MATRIX_INITIAL_SIZE
24 #define BK_MATRIX_INITIAL_SIZE 0
25 #endif
26 #ifndef BK_LCS_MATRIX_INITIAL_SIZE
27 #define BK_LCS_MATRIX_INITIAL_SIZE BK_MATRIX_INITIAL_SIZE
28 #endif
29 #ifndef BK_ED_MATRIX_INITIAL_SIZE
30 #define BK_ED_MATRIX_INITIAL_SIZE BK_MATRIX_INITIAL_SIZE
31 #endif
32 #ifndef BK_LEE_ALPHABET_SIZE
33 #define BK_LEE_ALPHABET_SIZE 26
34 #endif
35 #ifndef BK_TREE_INITIAL_SIZE
36 #define BK_TREE_INITIAL_SIZE 0
37 #endif
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <vector>
47 
51 namespace bk_tree {
52 
53 using integer_type = std::uint64_t;
54 
58 namespace metrics {
59 
63 template <typename Metric>
64 class Distance {
65 public:
66  integer_type operator()(std::string_view s, std::string_view t) const {
67  return (static_cast<Metric const *>(this))->compute_distance(s, t);
68  }
69 };
70 
78 class LengthDistance final : public Distance<LengthDistance> {
79 public:
80  explicit LengthDistance(){};
81  integer_type compute_distance(std::string_view s, std::string_view t) const {
82  std::int64_t const d = s.length() - t.length();
83  return d >= 0 ? d : -d;
84  }
85 };
86 
92 class IdentityDistance final : public Distance<IdentityDistance> {
93 public:
94  explicit IdentityDistance(){};
95  integer_type compute_distance(std::string_view, std::string_view) const {
96  return integer_type{1};
97  }
98 };
99 
107 class LeeDistance final : public Distance<LeeDistance> {
108  integer_type m_alphabet_size;
109 
110 public:
111  explicit LeeDistance(integer_type alphabet_size = BK_LEE_ALPHABET_SIZE)
112  : m_alphabet_size(alphabet_size){};
113  integer_type compute_distance(std::string_view s, std::string_view t) const {
114  const integer_type M = s.length(), N = t.length();
115  if (M != N) {
116  return std::numeric_limits<integer_type>::max();
117  }
118  const integer_type m_comparsion_size = M;
119  integer_type counter = 0, diff;
120  for (integer_type i = 0; i < m_comparsion_size; ++i) {
121  diff = std::abs(s[i] - t[i]);
122  counter += std::min(diff, m_alphabet_size - diff);
123  }
124  return counter;
125  }
126 };
127 
143 class LCSubseqDistance final : public Distance<LCSubseqDistance> {
144  mutable std::vector<integer_type> m_current, m_previous;
145 
146 public:
147  explicit LCSubseqDistance(size_t initial_size = BK_LCS_MATRIX_INITIAL_SIZE)
148  : m_current(initial_size), m_previous(initial_size){};
149  integer_type compute_distance(std::string_view s, std::string_view t) const {
150  const integer_type M = s.length(), N = t.length();
151  if (M == 0 || N == 0) {
152  return 0;
153  }
154  if (m_current.size() <= N || m_previous.size() <= N) {
155  m_current.resize(N + 1);
156  m_previous.resize(N + 1);
157  }
158  std::fill(m_previous.begin(), m_previous.end(), 0);
159  for (integer_type i = 1; i <= M; ++i) {
160  for (integer_type j = 1; j <= N; ++j) {
161  if (s[i - 1] == t[j - 1]) {
162  m_current[j] = m_previous[j - 1] + 1;
163  } else {
164  m_current[j] = std::max(m_previous[j], m_current[j - 1]);
165  }
166  }
167  m_previous = m_current;
168  }
169  return m_previous[N];
170  }
171 };
172 
180 class HammingDistance final : public Distance<HammingDistance> {
181 public:
182  explicit HammingDistance() = default;
183  integer_type compute_distance(std::string_view s, std::string_view t) const {
184  const integer_type M = s.length(), N = t.length();
185  if (M != N) {
186  return std::numeric_limits<integer_type>::max();
187  }
188  const integer_type m_comparsion_size = M;
189  integer_type counter = 0;
190  for (integer_type i = 0; i < m_comparsion_size; ++i) {
191  counter += (s[i] != t[i]);
192  }
193  return counter;
194  }
195 };
196 
221 class EditDistance final : public Distance<EditDistance> {
222  mutable std::vector<std::vector<integer_type>> m_matrix;
223 
224 public:
225  explicit EditDistance(size_t initial_size = BK_ED_MATRIX_INITIAL_SIZE)
226  : m_matrix(initial_size, std::vector<integer_type>(initial_size)){};
227  integer_type compute_distance(std::string_view s, std::string_view t) const {
228  const integer_type M = s.length(), N = t.length();
229  if (M == 0 || N == 0) {
230  return N + M;
231  }
232  if (m_matrix.size() <= M || m_matrix[0].size() <= N) {
233  std::vector<std::vector<integer_type>> a_matrix(M + 1,
234  std::vector<integer_type>(N + 1));
235  m_matrix.swap(a_matrix);
236  }
237  for (integer_type i = 1; i <= M; ++i) {
238  m_matrix[i][0] = i;
239  }
240  for (integer_type j = 1; j <= N; ++j) {
241  m_matrix[0][j] = j;
242  }
243  for (integer_type j = 1; j <= N; ++j) {
244  for (integer_type i = 1; i <= M; ++i) {
245  m_matrix[i][j] = std::min(
246  {m_matrix[i][j - 1] + 1 /*Insertion*/, m_matrix[i - 1][j] + 1 /*Deletion*/,
247  m_matrix[i - 1][j - 1] + (s[i - 1] != t[j - 1]) /*Substitution*/});
248  }
249  }
250  return m_matrix[M][N];
251  }
252 };
253 
259 class DamerauLevenshteinDistance final : public Distance<DamerauLevenshteinDistance> {
260  mutable std::vector<std::vector<integer_type>> m_matrix;
261 
262 public:
263  explicit DamerauLevenshteinDistance(size_t initial_size = BK_MATRIX_INITIAL_SIZE)
264  : m_matrix(initial_size, std::vector<integer_type>(initial_size)){};
265  integer_type compute_distance(std::string_view s, std::string_view t) const {
266  const integer_type M = s.length(), N = t.length();
267  if (M == 0 || N == 0) {
268  return N + M;
269  }
270  if (m_matrix.size() <= M || m_matrix[0].size() <= N) {
271  std::vector<std::vector<integer_type>> a_matrix(M + 1,
272  std::vector<integer_type>(N + 1));
273  m_matrix.swap(a_matrix);
274  }
275  for (integer_type i = 1; i <= M; ++i) {
276  m_matrix[i][0] = i;
277  }
278  for (integer_type j = 1; j <= N; ++j) {
279  m_matrix[0][j] = j;
280  }
281  for (integer_type j = 1; j <= N; ++j) {
282  for (integer_type i = 1; i <= M; ++i) {
283  m_matrix[i][j] = std::min(
284  {m_matrix[i][j - 1] + 1 /*Insertion*/, m_matrix[i - 1][j] + 1 /*Deletion*/,
285  m_matrix[i - 1][j - 1] + (s[i - 1] == t[j - 1] ? 0 : 1) /*Substitution*/});
286  if (i > 1 && j > 1 && s[i] == t[j - 1] && s[i - 1] == t[j]) {
287  m_matrix[i][j] =
288  std::min(m_matrix[i][j], m_matrix[i - 2][j - 2] + 1 /*Transposition*/);
289  }
290  }
291  }
292  return m_matrix[M][N];
293  }
294 };
295 
296 } // namespace metrics
297 
298 namespace helpers {
299 std::false_type is_metric_impl(...);
300 
301 template <typename Metric>
302 std::true_type is_metric_impl(const volatile metrics::Distance<Metric> &);
303 
304 template <typename Metric>
305 using is_metric = decltype(is_metric_impl(std::declval<Metric &>()));
306 } // namespace helpers
307 
// Forward declarations: the tree and its node type refer to each other.
template <typename Metric> class BKTree;
template <typename Metric> class BKTreeNode;

/// One search hit: the stored word and its distance to the query.
using ResultEntry = std::pair<std::string, int>;
/// All hits produced by a single search.
using ResultList = std::vector<ResultEntry>;
315 
316 template <typename Metric>
317 class BKTreeNode {
318  friend class BKTree<Metric>;
319  using metric_type = Metric;
320  using node_type = BKTreeNode<metric_type>;
321 
322  BKTreeNode(std::string_view value) : m_word(value) {}
323  bool _insert(std::string_view value, const metric_type &distance);
324  bool _erase(std::string_view value, const metric_type &distance);
325  void _find(ResultList &output, std::string_view value, const int &limit,
326  const metric_type &metric) const;
327  ResultList _find_wrapper(std::string_view value, const int &limit,
328  const metric_type &metric) const;
329 
330  std::map<int, std::unique_ptr<node_type>> m_children;
331  std::string m_word;
332 
333  friend std::ostream &operator<<(std::ostream &oss, const BKTreeNode &node) {
334  oss << node.m_word;
335  return oss;
336  }
337 
338 public:
339  std::string_view word() const noexcept { return m_word; }
340 };
341 
345 template <typename Metric>
346 class BKTree {
347  static_assert(helpers::is_metric<Metric>::value, "Metric must be of type Distance");
348 
349  using metric_type = Metric;
350  using node_type = typename BKTreeNode<metric_type>::node_type;
351 
352 public:
356  class Iterator {
357  public:
358  using iterator_category = std::forward_iterator_tag;
359  using difference_type = std::ptrdiff_t;
360  using value_type = std::unique_ptr<node_type>;
361  using pointer = std::unique_ptr<node_type> *;
362  using reference = std::unique_ptr<node_type> &;
363 
364  public:
365  Iterator() = default;
366  Iterator(pointer ptr) : m_pointer(ptr) {}
367 
368  pointer operator->() { return m_pointer; }
369 
370  reference operator*() const { return *m_pointer; }
371 
372  Iterator &operator++() {
373  if (m_pointer == nullptr) {
374  throw std::out_of_range("No more tree node");
375  }
376  for (auto &[_, child] : (*m_pointer)->m_children) {
377  m_queue.push(&child);
378  }
379  if (m_queue.empty()) {
380  m_pointer = nullptr;
381  } else {
382  m_pointer = m_queue.front();
383  m_queue.pop();
384  }
385  return *this;
386  }
387 
388  Iterator operator++(int) {
389  Iterator tmp{*this};
390  ++(*this);
391  return tmp;
392  }
393 
394  friend bool operator==(const Iterator &a, const Iterator &b) {
395  return a.m_pointer == b.m_pointer;
396  };
397 
398  friend bool operator!=(const Iterator &a, const Iterator &b) {
399  return a.m_pointer != b.m_pointer;
400  };
401 
402  private:
403  pointer m_pointer;
404  std::queue<pointer> m_queue;
405  };
406 
407 public:
408  BKTree(const metric_type &distance = Metric())
409  : m_root(nullptr), m_metric(distance), m_tree_size(BK_TREE_INITIAL_SIZE) {}
410 
411  BKTree(std::initializer_list<std::string_view> list)
412  : m_root(nullptr), m_metric(Metric()), m_tree_size(BK_TREE_INITIAL_SIZE) {
413  for (auto &str : list) {
414  insert(str);
415  }
416  }
417 
418  BKTree(const BKTree &other) : BKTree(other.m_metric) {
419  if (other.m_root == nullptr) {
420  return;
421  }
422  std::queue<std::unique_ptr<node_type> const *> bq;
423  bq.push(&(other.m_root));
424  while (!bq.empty()) {
425  auto *nptr = bq.front();
426  bq.pop();
427  this->insert((*nptr)->m_word);
428  for (auto &[_, child_node] : (*nptr)->m_children) {
429  bq.push(&child_node);
430  }
431  }
432  }
433 
434  BKTree(BKTree &&other) noexcept
435  : m_root(std::exchange(other.m_root, nullptr)), m_tree_size(other.m_tree_size) {}
436 
437  BKTree &operator=(const BKTree &other) {
438  if (this == &other) {
439  return *this;
440  }
441  BKTree temp(other);
442  std::swap(m_root, temp.m_root);
443  std::swap(m_tree_size, temp.m_tree_size);
444  return *this;
445  }
446 
447  BKTree &operator=(BKTree &&other) noexcept {
448  std::swap(m_root, other.m_root);
449  std::swap(m_tree_size, other.m_tree_size);
450  return *this;
451  }
452 
453  ~BKTree() = default;
454 
455  bool insert(std::string_view value);
456  bool erase(std::string_view value);
457  size_t size() const noexcept { return m_tree_size; }
458  bool empty() const noexcept { return m_tree_size == 0; }
459  [[nodiscard]] ResultList find(std::string_view value, const int &limit) const;
460 
461  Iterator begin() { return Iterator(&m_root); }
462  Iterator end() { return Iterator(); }
463 
464 private:
465  std::unique_ptr<node_type> m_root;
466  const metric_type m_metric;
467  size_t m_tree_size;
468 };
469 
470 template <typename Metric>
471 bool BKTreeNode<Metric>::_insert(std::string_view value,
472  const metric_type &distance_metric) {
473  const int distance_between = distance_metric(value, m_word);
474  bool inserted = false;
475  if (distance_between >= 0) {
476  auto it = m_children.find(distance_between);
477  if (it == m_children.end()) {
478  m_children.emplace(std::make_pair(
479  distance_between, std::unique_ptr<node_type>(new node_type(value))));
480  inserted = true;
481  } else {
482  inserted = it->second->_insert(value, distance_metric);
483  }
484  }
485  return inserted;
486 }
487 
488 template <typename Metric>
489 bool BKTreeNode<Metric>::_erase(std::string_view value,
490  const metric_type &distance_metric) {
491  bool erased = false;
492  const int distance_between = distance_metric(value, m_word);
493  auto it = m_children.find(distance_between);
494  if (it != m_children.end()) {
495  if (it->second->m_word == value) {
496  auto node = std::move(it->second);
497  m_children.erase(it);
498  std::queue<std::unique_ptr<node_type> const *> bq;
499  for (auto const &[_, child_node] : node->m_children) {
500  bq.push(&child_node);
501  }
502  while (!bq.empty()) {
503  auto *node = bq.front();
504  bq.pop();
505  for (auto const &[_, child_node] : (*node)->m_children) {
506  bq.push(&child_node);
507  }
508  _insert((*node)->m_word, distance_metric);
509  }
510  erased = true;
511  } else {
512  erased = it->second->_erase(value, distance_metric);
513  }
514  } else {
515  for (auto const &[_, child] : m_children) {
516  if (child->_erase(value, distance_metric)) {
517  return true;
518  }
519  }
520  }
521  return erased;
522 }
523 
524 template <typename Metric>
525 void BKTreeNode<Metric>::_find(ResultList &output, std::string_view value,
526  const int &limit, const metric_type &metric) const {
527  const int distance = metric(value, m_word);
528  if (distance <= limit) {
529  output.push_back({m_word, distance});
530  }
531  for (auto const &[dist, node] : m_children) {
532  if (std::abs(dist - distance) <= limit) {
533  node->_find(output, value, limit, metric);
534  }
535  }
536 }
537 
538 template <typename Metric>
539 ResultList BKTreeNode<Metric>::_find_wrapper(std::string_view value, const int &limit,
540  const metric_type &metric) const {
541  ResultList output;
542  _find(output, value, limit, metric);
543  return output;
544 }
545 
546 template <typename Metric>
547 bool BKTree<Metric>::insert(std::string_view value) {
548  bool inserted = false;
549  if (m_root == nullptr) {
550  m_root = std::unique_ptr<node_type>(new node_type(value));
551  ++m_tree_size;
552  inserted = true;
553  } else if (m_root->_insert(value, m_metric)) {
554  ++m_tree_size;
555  inserted = true;
556  }
557  return inserted;
558 }
559 
560 template <typename Metric>
561 bool BKTree<Metric>::erase(std::string_view value) {
562  bool erased = false;
563  if (m_root == nullptr) {
564  erased = true;
565  } else if (m_root->m_word == value) {
566  if (m_tree_size > 1) {
567  auto &replacement_node = m_root->m_children.begin()->second;
568  std::queue<std::unique_ptr<node_type> const *> bq;
569  for (bool first = true; auto const &[_, node] : m_root->m_children) {
570  if (first) {
571  first = false;
572  continue;
573  }
574  bq.push(&node);
575  }
576  while (!bq.empty()) {
577  auto node = bq.front();
578  bq.pop();
579  for (auto const &[_, child] : (*node)->m_children) {
580  bq.push(&child);
581  }
582  replacement_node->_insert((*node)->m_word, m_metric);
583  }
584  m_root = std::move(replacement_node);
585  } else {
586  m_root.reset(nullptr);
587  }
588  --m_tree_size;
589  erased = true;
590  } else if (m_root->_erase(value, m_metric)) {
591  --m_tree_size;
592  erased = true;
593  }
594  return erased;
595 }
596 
597 template <typename Metric>
598 ResultList BKTree<Metric>::find(std::string_view value, const int &limit) const {
599  if (m_root == nullptr) {
600  return ResultList{};
601  }
602  return m_root->_find_wrapper(value, limit, m_metric);
603 }
604 
605 } // namespace bk_tree
BK-tree class iterator.
Definition: bktree.hpp:356
BK-tree template class.
Definition: bktree.hpp:346
Damerau–Levenshtein metric.
Definition: bktree.hpp:259
Metric interface for string distances.
Definition: bktree.hpp:64
Edit distance metric.
Definition: bktree.hpp:221
Hamming distance metric.
Definition: bktree.hpp:180
Identity metric.
Definition: bktree.hpp:92
Longest Common Subsequence distance metric.
Definition: bktree.hpp:143
Lee distance metric.
Definition: bktree.hpp:107
Length metric.
Definition: bktree.hpp:78
BK-tree namespace.
Definition: bktree.hpp:51