bk-tree 0.1.4
Header-only Burkhard-Keller tree library
Loading...
Searching...
No Matches
bktree.hpp
1//
2// bk-tree Header-only Burkhard-Keller tree library
3// Copyright (C) 2020-2023 John Law
4//
5// This file is part of bk-tree.
6//
7// bk-tree is free software: you can redistribute it and/or modify
8// it under the terms of the GNU General Public License as published by
9// the Free Software Foundation, either version 3 of the License, or
10// (at your option) any later version.
11//
12// bk-tree is distributed in the hope that it will be useful,
13// but WITHOUT ANY WARRANTY; without even the implied warranty of
14// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15// GNU General Public License for more details.
16//
17// You should have received a copy of the GNU General Public License
18// along with bk-tree. If not, see <https://www.gnu.org/licenses/>.
19//
20
21#pragma once
22
// Default number of rows/columns the matrix-based metrics allocate up front.
// The LCS and edit-distance specific sizes below fall back to this value.
#if !defined(BK_MATRIX_INITIAL_SIZE)
#define BK_MATRIX_INITIAL_SIZE 0
#endif
// Initial length of the two row buffers used by LCSubseqDistance.
#if !defined(BK_LCS_MATRIX_INITIAL_SIZE)
#define BK_LCS_MATRIX_INITIAL_SIZE BK_MATRIX_INITIAL_SIZE
#endif
// Initial side length of the matrix used by EditDistance.
#if !defined(BK_ED_MATRIX_INITIAL_SIZE)
#define BK_ED_MATRIX_INITIAL_SIZE BK_MATRIX_INITIAL_SIZE
#endif
// Default alphabet size used by LeeDistance.
#if !defined(BK_LEE_ALPHABET_SIZE)
#define BK_LEE_ALPHABET_SIZE 26
#endif
// Value a freshly constructed BKTree starts its size counter at.
#if !defined(BK_TREE_INITIAL_SIZE)
#define BK_TREE_INITIAL_SIZE 0
#endif
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <ostream>
#include <queue>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <vector>
48
52namespace bk_tree {
53
/// Unsigned 64-bit type in which every metric reports its distance.
using integer_type = std::uint64_t;
55
59namespace metrics {
60
64template <typename Metric>
65class Distance {
66public:
67 integer_type operator()(std::string_view s, std::string_view t) const {
68 return (static_cast<Metric const *>(this))->compute_distance(s, t);
69 }
70};
71
79class LengthDistance final : public Distance<LengthDistance> {
80public:
81 explicit LengthDistance(){};
82 integer_type compute_distance(std::string_view s, std::string_view t) const noexcept {
83 return s.length() > t.length() ? s.length() - t.length() : t.length() - s.length();
84 }
85};
86
92class IdentityDistance final : public Distance<IdentityDistance> {
93public:
94 explicit IdentityDistance(){};
95 integer_type compute_distance(std::string_view, std::string_view) const noexcept {
96 return integer_type{1};
97 }
98};
99
107class LeeDistance final : public Distance<LeeDistance> {
108 integer_type m_alphabet_size;
109
110public:
111 explicit LeeDistance(integer_type alphabet_size = BK_LEE_ALPHABET_SIZE)
112 : m_alphabet_size(alphabet_size){};
113 integer_type compute_distance(std::string_view s, std::string_view t) const noexcept {
114 const integer_type M = s.length(), N = t.length();
115 if (M != N) {
116 return std::numeric_limits<integer_type>::max();
117 }
118 const integer_type m_comparison_size = M;
119 integer_type counter = 0, diff;
120 for (integer_type i = 0; i < m_comparison_size; ++i) {
121 diff = std::abs(s[i] - t[i]);
122 counter += std::min(diff, m_alphabet_size - diff);
123 }
124 return counter;
125 }
126};
127
143class LCSubseqDistance final : public Distance<LCSubseqDistance> {
144 mutable std::vector<integer_type> m_current, m_previous;
145
146public:
147 explicit LCSubseqDistance(size_t initial_size = BK_LCS_MATRIX_INITIAL_SIZE)
148 : m_current(initial_size), m_previous(initial_size){};
149 integer_type compute_distance(std::string_view s, std::string_view t) const noexcept {
150 const integer_type M = s.length(), N = t.length();
151 if (M == 0 || N == 0) {
152 return 0;
153 }
154 if (m_current.size() <= N || m_previous.size() <= N) {
155 m_current.resize(N + 1);
156 m_previous.resize(N + 1);
157 }
158 std::fill(m_previous.begin(), m_previous.end(), 0);
159 for (integer_type i = 1; i <= M; ++i) {
160 for (integer_type j = 1; j <= N; ++j) {
161 if (s[i - 1] == t[j - 1]) {
162 m_current[j] = m_previous[j - 1] + 1;
163 } else {
164 m_current[j] = std::max(m_previous[j], m_current[j - 1]);
165 }
166 }
167 m_previous = m_current;
168 }
169 return m_previous[N];
170 }
171};
172
180class HammingDistance final : public Distance<HammingDistance> {
181public:
182 explicit HammingDistance() = default;
183 integer_type compute_distance(std::string_view s, std::string_view t) const noexcept {
184 const integer_type M = s.length(), N = t.length();
185 if (M != N) {
186 return std::numeric_limits<integer_type>::max();
187 }
188 const integer_type m_comparison_size = M;
189 integer_type counter = 0;
190 for (integer_type i = 0; i < m_comparison_size; ++i) {
191 counter += (s[i] != t[i]);
192 }
193 return counter;
194 }
195};
196
221class EditDistance final : public Distance<EditDistance> {
222 mutable std::vector<std::vector<integer_type>> m_matrix;
223
224public:
225 explicit EditDistance(size_t initial_size = BK_ED_MATRIX_INITIAL_SIZE)
226 : m_matrix(initial_size, std::vector<integer_type>(initial_size)){};
227 integer_type compute_distance(std::string_view s, std::string_view t) const noexcept {
228 const integer_type M = s.length(), N = t.length();
229 if (M == 0 || N == 0) {
230 return N + M;
231 }
232 if (m_matrix.size() <= M || m_matrix[0].size() <= N) {
233 std::vector<std::vector<integer_type>> a_matrix(M + 1,
234 std::vector<integer_type>(N + 1));
235 m_matrix.swap(a_matrix);
236 }
237 for (integer_type i = 1; i <= M; ++i) {
238 m_matrix[i][0] = i;
239 }
240 for (integer_type j = 1; j <= N; ++j) {
241 m_matrix[0][j] = j;
242 }
243 for (integer_type j = 1; j <= N; ++j) {
244 for (integer_type i = 1; i <= M; ++i) {
245 m_matrix[i][j] = std::min(
246 {m_matrix[i][j - 1] + 1 /*Insertion*/, m_matrix[i - 1][j] + 1 /*Deletion*/,
247 m_matrix[i - 1][j - 1] + (s[i - 1] != t[j - 1]) /*Substitution*/});
248 }
249 }
250 return m_matrix[M][N];
251 }
252};
253
259class DamerauLevenshteinDistance final : public Distance<DamerauLevenshteinDistance> {
260 mutable std::vector<std::vector<integer_type>> m_matrix;
261
262public:
263 explicit DamerauLevenshteinDistance(size_t initial_size = BK_MATRIX_INITIAL_SIZE)
264 : m_matrix(initial_size, std::vector<integer_type>(initial_size)){};
265 integer_type compute_distance(std::string_view s, std::string_view t) const noexcept {
266 const integer_type M = s.length(), N = t.length();
267 if (M == 0 || N == 0) {
268 return N + M;
269 }
270 if (m_matrix.size() <= M || m_matrix[0].size() <= N) {
271 std::vector<std::vector<integer_type>> a_matrix(M + 1,
272 std::vector<integer_type>(N + 1));
273 m_matrix.swap(a_matrix);
274 }
275 for (integer_type i = 1; i <= M; ++i) {
276 m_matrix[i][0] = i;
277 }
278 for (integer_type j = 1; j <= N; ++j) {
279 m_matrix[0][j] = j;
280 }
281 for (integer_type j = 1; j <= N; ++j) {
282 for (integer_type i = 1; i <= M; ++i) {
283 m_matrix[i][j] = std::min(
284 {m_matrix[i][j - 1] + 1 /*Insertion*/, m_matrix[i - 1][j] + 1 /*Deletion*/,
285 m_matrix[i - 1][j - 1] + (s[i - 1] == t[j - 1] ? 0 : 1) /*Substitution*/});
286 if (i > 1 && j > 1 && s[i - 1] == t[j - 2] && s[i - 2] == t[j - 1]) {
287 m_matrix[i][j] =
288 std::min(m_matrix[i][j], m_matrix[i - 2][j - 2] + 1 /*Transposition*/);
289 }
290 }
291 }
292 return m_matrix[M][N];
293 }
294};
295
296} // namespace metrics
297
298namespace helpers {
299std::false_type is_metric_impl(...);
300
301template <typename Metric>
302std::true_type is_metric_impl(const volatile metrics::Distance<Metric> &);
303
304template <typename Metric>
305using is_metric = decltype(is_metric_impl(std::declval<Metric &>()));
306} // namespace helpers
307
/// One match produced by BKTree::find: the stored word and its distance to the query.
using ResultEntry = std::pair<std::string, int>;
/// All matches produced by a single BKTree::find call.
using ResultList = std::vector<ResultEntry>;

template <typename Metric>
class BKTree;
template <typename Metric>
class BKTreeNode;
315
316template <typename Metric>
317class BKTreeNode {
318 friend class BKTree<Metric>;
319 using metric_type = Metric;
320 using node_type = BKTreeNode<metric_type>;
321
322 BKTreeNode(std::string_view value) : m_word(value) {}
323 bool _insert(std::string_view value, const metric_type &distance);
324 bool _erase(std::string_view value, const metric_type &distance);
325 void _find(ResultList &output, std::string_view value, int limit,
326 const metric_type &metric) const;
327 ResultList _find_wrapper(std::string_view value, int limit,
328 const metric_type &metric) const;
329
330 std::map<int, std::unique_ptr<node_type>> m_children;
331 std::string m_word;
332
333 friend std::ostream &operator<<(std::ostream &oss, const BKTreeNode &node) {
334 oss << node.m_word;
335 return oss;
336 }
337
338public:
339 std::string_view word() const noexcept { return m_word; }
340};
341
345template <typename Metric>
346class BKTree {
347 static_assert(helpers::is_metric<Metric>::value, "Metric must be of type Distance");
348
349 using metric_type = Metric;
350 using node_type = typename BKTreeNode<metric_type>::node_type;
351
352public:
356 class Iterator {
357 public:
358 using iterator_category = std::forward_iterator_tag;
359 using difference_type = std::ptrdiff_t;
360 using value_type = std::unique_ptr<node_type>;
361 using pointer = std::unique_ptr<node_type> *;
362 using reference = std::unique_ptr<node_type> &;
363
364 public:
365 Iterator() = default;
366 Iterator(pointer ptr) : m_pointer(ptr) {}
367
368 pointer operator->() { return m_pointer; }
369
370 reference operator*() const { return *m_pointer; }
371
372 Iterator &operator++() {
373 if (m_pointer == nullptr) {
374 throw std::out_of_range("No more tree node");
375 }
376 for (auto &[_, child] : (*m_pointer)->m_children) {
377 m_queue.push(&child);
378 }
379 if (m_queue.empty()) {
380 m_pointer = nullptr;
381 } else {
382 m_pointer = m_queue.front();
383 m_queue.pop();
384 }
385 return *this;
386 }
387
388 Iterator operator++(int) {
389 Iterator tmp{*this};
390 ++(*this);
391 return tmp;
392 }
393
394 friend bool operator==(const Iterator &a, const Iterator &b) {
395 return a.m_pointer == b.m_pointer;
396 };
397
398 friend bool operator!=(const Iterator &a, const Iterator &b) {
399 return a.m_pointer != b.m_pointer;
400 };
401
402 private:
403 pointer m_pointer;
404 std::queue<pointer> m_queue;
405 };
406
407public:
408 BKTree(const metric_type &distance = Metric())
409 : m_root(nullptr), m_metric(distance), m_tree_size(BK_TREE_INITIAL_SIZE) {}
410
411 BKTree(std::initializer_list<std::string_view> list)
412 : m_root(nullptr), m_metric(Metric()), m_tree_size(BK_TREE_INITIAL_SIZE) {
413 for (auto &str : list) {
414 insert(str);
415 }
416 }
417
418 BKTree(const BKTree &other) : BKTree(other.m_metric) {
419 if (other.m_root == nullptr) {
420 return;
421 }
422 std::queue<std::unique_ptr<node_type> const *> bq;
423 bq.push(&(other.m_root));
424 while (!bq.empty()) {
425 auto *nptr = bq.front();
426 bq.pop();
427 this->insert((*nptr)->m_word);
428 for (auto &[_, child_node] : (*nptr)->m_children) {
429 bq.push(&child_node);
430 }
431 }
432 }
433
434 BKTree(BKTree &&other) noexcept
435 : m_root(std::exchange(other.m_root, nullptr)), m_tree_size(other.m_tree_size) {}
436
437 BKTree &operator=(const BKTree &other) {
438 if (this == &other) {
439 return *this;
440 }
441 BKTree temp(other);
442 std::swap(m_root, temp.m_root);
443 std::swap(m_tree_size, temp.m_tree_size);
444 return *this;
445 }
446
447 BKTree &operator=(BKTree &&other) noexcept {
448 std::swap(m_root, other.m_root);
449 std::swap(m_tree_size, other.m_tree_size);
450 return *this;
451 }
452
453 ~BKTree() = default;
454
455 bool insert(std::string_view value);
456 bool erase(std::string_view value);
457 size_t size() const noexcept { return m_tree_size; }
458 bool empty() const noexcept { return m_tree_size == 0; }
459 [[nodiscard]] ResultList find(std::string_view value, int limit) const;
460
461 Iterator begin() { return Iterator(&m_root); }
462 Iterator end() { return Iterator(); }
463
464private:
465 std::unique_ptr<node_type> m_root;
466 const metric_type m_metric;
467 size_t m_tree_size;
468};
469
470template <typename Metric>
471bool BKTreeNode<Metric>::_insert(std::string_view value,
472 const metric_type &distance_metric) {
473 const int distance_between = distance_metric(value, m_word);
474 bool inserted = false;
475 if (distance_between >= 0) {
476 auto it = m_children.find(distance_between);
477 if (it == m_children.end()) {
478 m_children.emplace(std::make_pair(
479 distance_between, std::unique_ptr<node_type>(new node_type(value))));
480 inserted = true;
481 } else {
482 inserted = it->second->_insert(value, distance_metric);
483 }
484 }
485 return inserted;
486}
487
488template <typename Metric>
489bool BKTreeNode<Metric>::_erase(std::string_view value,
490 const metric_type &distance_metric) {
491 bool erased = false;
492 const int distance_between = distance_metric(value, m_word);
493 auto it = m_children.find(distance_between);
494 if (it != m_children.end()) {
495 if (it->second->m_word == value) {
496 auto node = std::move(it->second);
497 m_children.erase(it);
498 std::queue<std::unique_ptr<node_type> const *> bq;
499 for (auto const &[_, child_node] : node->m_children) {
500 bq.push(&child_node);
501 }
502 while (!bq.empty()) {
503 auto *node = bq.front();
504 bq.pop();
505 for (auto const &[_, child_node] : (*node)->m_children) {
506 bq.push(&child_node);
507 }
508 _insert((*node)->m_word, distance_metric);
509 }
510 erased = true;
511 } else {
512 erased = it->second->_erase(value, distance_metric);
513 }
514 } else {
515 for (auto const &[_, child] : m_children) {
516 if (child->_erase(value, distance_metric)) {
517 return true;
518 }
519 }
520 }
521 return erased;
522}
523
524template <typename Metric>
525void BKTreeNode<Metric>::_find(ResultList &output, std::string_view value,
526 int limit, const metric_type &metric) const {
527 const int distance = metric(value, m_word);
528 if (distance <= limit) {
529 output.push_back({m_word, distance});
530 }
531 for (auto const &[dist, node] : m_children) {
532 if (std::abs(dist - distance) <= limit) {
533 node->_find(output, value, limit, metric);
534 }
535 }
536}
537
538template <typename Metric>
539ResultList BKTreeNode<Metric>::_find_wrapper(std::string_view value, int limit,
540 const metric_type &metric) const {
541 ResultList output;
542 _find(output, value, limit, metric);
543 return output;
544}
545
546template <typename Metric>
547bool BKTree<Metric>::insert(std::string_view value) {
548 bool inserted = false;
549 if (m_root == nullptr) {
550 m_root = std::unique_ptr<node_type>(new node_type(value));
551 ++m_tree_size;
552 inserted = true;
553 } else if (m_root->_insert(value, m_metric)) {
554 ++m_tree_size;
555 inserted = true;
556 }
557 return inserted;
558}
559
560template <typename Metric>
561bool BKTree<Metric>::erase(std::string_view value) {
562 bool erased = false;
563 if (m_root == nullptr) {
564 erased = false;
565 } else if (m_root->m_word == value) {
566 if (m_tree_size > 1) {
567 auto &replacement_node = m_root->m_children.begin()->second;
568 std::queue<std::unique_ptr<node_type> const *> bq;
569 for (bool first = true; auto const &[_, node] : m_root->m_children) {
570 if (first) {
571 first = false;
572 continue;
573 }
574 bq.push(&node);
575 }
576 while (!bq.empty()) {
577 auto node = bq.front();
578 bq.pop();
579 for (auto const &[_, child] : (*node)->m_children) {
580 bq.push(&child);
581 }
582 replacement_node->_insert((*node)->m_word, m_metric);
583 }
584 m_root = std::move(replacement_node);
585 } else {
586 m_root.reset(nullptr);
587 }
588 --m_tree_size;
589 erased = true;
590 } else if (m_root->_erase(value, m_metric)) {
591 --m_tree_size;
592 erased = true;
593 }
594 return erased;
595}
596
597template <typename Metric>
598ResultList BKTree<Metric>::find(std::string_view value, int limit) const {
599 if (m_root == nullptr) {
600 return ResultList{};
601 }
602 return m_root->_find_wrapper(value, limit, m_metric);
603}
604
605} // namespace bk_tree
BK-tree class iterator.
Definition bktree.hpp:356
BK-tree template class.
Definition bktree.hpp:346
Damerau–Levenshtein metric.
Definition bktree.hpp:259
Metric interface for string distances.
Definition bktree.hpp:65
Edit distance metric.
Definition bktree.hpp:221
Hamming distance metric.
Definition bktree.hpp:180
Identity metric.
Definition bktree.hpp:92
Longest Common Subsequence distance metric.
Definition bktree.hpp:143
Lee distance metric.
Definition bktree.hpp:107
Length metric.
Definition bktree.hpp:79
BK-tree namespace.
Definition bktree.hpp:52