// Compile-time configuration. Each macro is only defined here when the
// includer has not already defined it, so every value can be overridden.

// Shared fallback value for the per-metric initial-size macros below.
23 #ifndef BK_MATRIX_INITIAL_SIZE
24 #define BK_MATRIX_INITIAL_SIZE 0
// LCS-specific initial size; defaults to BK_MATRIX_INITIAL_SIZE.
26 #ifndef BK_LCS_MATRIX_INITIAL_SIZE
27 #define BK_LCS_MATRIX_INITIAL_SIZE BK_MATRIX_INITIAL_SIZE
// Default `initial_size` argument of the EditDistance constructor; defaults
// to BK_MATRIX_INITIAL_SIZE.
29 #ifndef BK_ED_MATRIX_INITIAL_SIZE
30 #define BK_ED_MATRIX_INITIAL_SIZE BK_MATRIX_INITIAL_SIZE
// Default `alphabet_size` argument of the LeeDistance constructor.
32 #ifndef BK_LEE_ALPHABET_SIZE
33 #define BK_LEE_ALPHABET_SIZE 26
// Value the BKTree constructors store into m_tree_size for a new tree.
35 #ifndef BK_TREE_INITIAL_SIZE
36 #define BK_TREE_INITIAL_SIZE 0
// Unsigned 64-bit type returned by operator() and every compute_distance
// in this file.
53 using integer_type = std::uint64_t;
// Call operator of the metric base template: casts `this` to
// `Metric const *` and forwards both strings to that type's
// compute_distance(), so the dispatch is resolved statically.
// NOTE(review): the class header that goes with this template line is not
// visible in this listing.
63 template <
typename Metric>
66 integer_type operator()(std::string_view s, std::string_view t)
const {
67 return (
static_cast<Metric
const *
>(
this))->compute_distance(s, t);
// Length metric: returns the absolute difference of the two string lengths.
// The subtraction is performed on unsigned lengths and stored in a signed
// 64-bit value so the sign can be tested before negating.
81 integer_type compute_distance(std::string_view s, std::string_view t)
const {
82 std::int64_t
const d = s.length() - t.length();
83 return d >= 0 ? d : -d;
// Identity metric: ignores both arguments and always returns 1.
95 integer_type compute_distance(std::string_view, std::string_view)
const {
96 return integer_type{1};
// Alphabet size used as the modulus when wrapping per-character differences.
108 integer_type m_alphabet_size;
// Stores the alphabet size; defaults to BK_LEE_ALPHABET_SIZE.
111 explicit LeeDistance(integer_type alphabet_size = BK_LEE_ALPHABET_SIZE)
112 : m_alphabet_size(alphabet_size){};
// Lee metric: sums, over every position i, min(d, m_alphabet_size - d) where
// d = |s[i] - t[i]|.
// NOTE(review): the condition guarding the early return of
// numeric_limits<integer_type>::max() is not visible in this listing;
// presumably it is taken when the lengths M and N differ -- confirm.
113 integer_type compute_distance(std::string_view s, std::string_view t)
const {
114 const integer_type M = s.length(), N = t.length();
116 return std::numeric_limits<integer_type>::max();
// Only the first M characters of both strings are compared.
118 const integer_type m_comparsion_size = M;
119 integer_type counter = 0, diff;
120 for (integer_type i = 0; i < m_comparsion_size; ++i) {
// Character difference is taken on the raw char codes.
121 diff = std::abs(s[i] - t[i]);
// m_alphabet_size - diff is unsigned: when diff exceeds the alphabet size it
// wraps to a huge value and std::min then selects diff.
122 counter += std::min(diff, m_alphabet_size - diff);
// Two DP rows reused across calls; `mutable` because compute_distance() is
// const but writes to them.
144 mutable std::vector<integer_type> m_current, m_previous;
// Constructor initializer list: both rows start with `initial_size` elements.
// NOTE(review): the constructor's signature line is not visible here.
148 : m_current(initial_size), m_previous(initial_size){};
// LCS metric: runs the two-row longest-common-subsequence recurrence over
// s (rows) and t (columns) and returns the value left in column N of the
// final row.
// NOTE(review): the value returned when either string is empty is not
// visible in this listing.
149 integer_type compute_distance(std::string_view s, std::string_view t)
const {
150 const integer_type M = s.length(), N = t.length();
151 if (M == 0 || N == 0) {
// Grow both rows so that index N is valid; they are never shrunk here.
154 if (m_current.size() <= N || m_previous.size() <= N) {
155 m_current.resize(N + 1);
156 m_previous.resize(N + 1);
// Row 0 of the recurrence is all zeros.
158 std::fill(m_previous.begin(), m_previous.end(), 0);
159 for (integer_type i = 1; i <= M; ++i) {
160 for (integer_type j = 1; j <= N; ++j) {
// Matching characters extend the diagonal value by one.
161 if (s[i - 1] == t[j - 1]) {
162 m_current[j] = m_previous[j - 1] + 1;
// Otherwise carry the better of the cell above and the cell to the left.
164 m_current[j] = std::max(m_previous[j], m_current[j - 1]);
// The finished row becomes the previous row (full vector copy).
167 m_previous = m_current;
169 return m_previous[N];
// Hamming metric: counts the positions at which s and t hold different
// characters.
// NOTE(review): the condition guarding the early return of
// numeric_limits<integer_type>::max() is not visible in this listing;
// presumably it is taken when the lengths M and N differ -- confirm.
183 integer_type compute_distance(std::string_view s, std::string_view t)
const {
184 const integer_type M = s.length(), N = t.length();
186 return std::numeric_limits<integer_type>::max();
// Only the first M characters of both strings are compared.
188 const integer_type m_comparsion_size = M;
189 integer_type counter = 0;
190 for (integer_type i = 0; i < m_comparsion_size; ++i) {
// The bool result of != is added as 0 or 1.
191 counter += (s[i] != t[i]);
// DP matrix reused across calls; `mutable` because compute_distance() is
// const but writes to it.
222 mutable std::vector<std::vector<integer_type>> m_matrix;
// Allocates an initial_size x initial_size matrix; the default size is
// BK_ED_MATRIX_INITIAL_SIZE.
225 explicit EditDistance(
size_t initial_size = BK_ED_MATRIX_INITIAL_SIZE)
226 : m_matrix(initial_size, std::vector<integer_type>(initial_size)){};
// Edit metric: fills m_matrix with the insert/delete/substitute recurrence
// (each operation costs 1) and returns the cell for the full strings.
// NOTE(review): the value returned for an empty input and the bodies of the
// two border-initialisation loops are not visible in this listing.
227 integer_type compute_distance(std::string_view s, std::string_view t)
const {
228 const integer_type M = s.length(), N = t.length();
229 if (M == 0 || N == 0) {
// Replace the matrix with a fresh (M+1) x (N+1) one when either dimension is
// too small. m_matrix[0] is only evaluated when the row count is already
// greater than M, so an empty matrix is never indexed.
232 if (m_matrix.size() <= M || m_matrix[0].size() <= N) {
233 std::vector<std::vector<integer_type>> a_matrix(M + 1,
234 std::vector<integer_type>(N + 1));
235 m_matrix.swap(a_matrix);
237 for (integer_type i = 1; i <= M; ++i) {
240 for (integer_type j = 1; j <= N; ++j) {
// Column-by-column fill of the recurrence.
243 for (integer_type j = 1; j <= N; ++j) {
244 for (integer_type i = 1; i <= M; ++i) {
// min(left + 1, above + 1, diagonal + (characters differ ? 1 : 0)).
245 m_matrix[i][j] = std::min(
246 {m_matrix[i][j - 1] + 1 , m_matrix[i - 1][j] + 1 ,
247 m_matrix[i - 1][j - 1] + (s[i - 1] != t[j - 1]) });
250 return m_matrix[M][N];
// DP matrix reused across calls; `mutable` because compute_distance() is
// const but writes to it.
260 mutable std::vector<std::vector<integer_type>> m_matrix;
// Constructor initializer list: allocates an initial_size x initial_size
// matrix. NOTE(review): the constructor's signature line is not visible here.
264 : m_matrix(initial_size, std::vector<integer_type>(initial_size)){};
265 integer_type compute_distance(std::string_view s, std::string_view t)
const {
266 const integer_type M = s.length(), N = t.length();
267 if (M == 0 || N == 0) {
270 if (m_matrix.size() <= M || m_matrix[0].size() <= N) {
271 std::vector<std::vector<integer_type>> a_matrix(M + 1,
272 std::vector<integer_type>(N + 1));
273 m_matrix.swap(a_matrix);
275 for (integer_type i = 1; i <= M; ++i) {
278 for (integer_type j = 1; j <= N; ++j) {
281 for (integer_type j = 1; j <= N; ++j) {
282 for (integer_type i = 1; i <= M; ++i) {
283 m_matrix[i][j] = std::min(
284 {m_matrix[i][j - 1] + 1 , m_matrix[i - 1][j] + 1 ,
285 m_matrix[i - 1][j - 1] + (s[i - 1] == t[j - 1] ? 0 : 1) });
286 if (i > 1 && j > 1 && s[i] == t[j - 1] && s[i - 1] == t[j]) {
288 std::min(m_matrix[i][j], m_matrix[i - 2][j - 2] + 1 );
292 return m_matrix[M][N];
// Fallback overload: the C-style variadic parameter accepts any argument and
// yields std::false_type.
299 std::false_type is_metric_impl(...);
// NOTE(review): the overload declared under this template header is not
// visible in this listing.
301 template <
typename Metric>
// is_metric<Metric> is the return type chosen by overload resolution when
// is_metric_impl is called with an lvalue of type Metric (unevaluated).
304 template <
typename Metric>
305 using is_metric = decltype(is_metric_impl(std::declval<Metric &>()));
308 template <
typename Metric>
310 template <
typename Metric>
// One search hit: a word paired with an int (BKTreeNode::_find stores the
// word's distance to the query there).
313 using ResultEntry = std::pair<std::string, int>;
// Collection of hits returned by the find functions.
314 using ResultList = std::vector<ResultEntry>;
// Node of the BK-tree: holds one word and its children keyed by an int
// distance. NOTE(review): the class header line and the m_word declaration
// are not visible in this listing.
316 template <
typename Metric>
// BKTree<Metric> is granted access to the node's members.
318 friend class BKTree<Metric>;
319 using metric_type = Metric;
320 using node_type = BKTreeNode<metric_type>;
// Creates a node holding `value`, with no children.
322 BKTreeNode(std::string_view value) : m_word(value) {}
// Inserts `value` somewhere below this node; returns whether it was inserted.
323 bool _insert(std::string_view value,
const metric_type &distance);
// Removes `value` from below this node; returns whether it was removed.
324 bool _erase(std::string_view value,
const metric_type &distance);
// Appends to `output` every word in this subtree within `limit` of `value`.
325 void _find(ResultList &output, std::string_view value,
const int &limit,
326 const metric_type &metric)
const;
// Convenience form of _find that returns the list instead of filling one.
327 ResultList _find_wrapper(std::string_view value,
const int &limit,
328 const metric_type &metric)
const;
// Children keyed by their distance to this node's word, ordered by key.
330 std::map<int, std::unique_ptr<node_type>> m_children;
// Stream output for a node (body not visible in this listing).
333 friend std::ostream &operator<<(std::ostream &oss,
const BKTreeNode &node) {
// Read-only view of the stored word.
339 std::string_view word() const noexcept {
return m_word; }
// Opening of the BKTree class template (header line not visible here).
345 template <
typename Metric>
// Compile-time check that Metric satisfies helpers::is_metric.
347 static_assert(helpers::is_metric<Metric>::value,
"Metric must be of type Distance");
349 using metric_type = Metric;
// Same node type that BKTreeNode<metric_type> names for itself.
350 using node_type =
typename BKTreeNode<metric_type>::node_type;
// Forward iterator over the tree. Dereferencing yields the
// std::unique_ptr that owns the current node, not the node itself.
358 using iterator_category = std::forward_iterator_tag;
359 using difference_type = std::ptrdiff_t;
360 using value_type = std::unique_ptr<node_type>;
361 using pointer = std::unique_ptr<node_type> *;
362 using reference = std::unique_ptr<node_type> &;
// Starts iteration at the owning pointer `ptr`.
366 Iterator(pointer ptr) : m_pointer(ptr) {}
368 pointer operator->() {
return m_pointer; }
370 reference operator*()
const {
return *m_pointer; }
// Advance step (operator signature not visible in this listing): advancing
// an iterator whose pointer is null throws std::out_of_range.
373 if (m_pointer ==
nullptr) {
374 throw std::out_of_range(
"No more tree node");
// Queue the current node's children, so nodes are visited in FIFO order.
376 for (
auto &[_, child] : (*m_pointer)->m_children) {
377 m_queue.push(&child);
// NOTE(review): what happens when the queue is empty, and the pop that
// follows front(), are not visible in this listing.
379 if (m_queue.empty()) {
382 m_pointer = m_queue.front();
// Iterators compare equal when they point at the same owning pointer.
395 return a.m_pointer == b.m_pointer;
399 return a.m_pointer != b.m_pointer;
// Pending owning pointers still to be visited.
404 std::queue<pointer> m_queue;
// Creates an empty tree (null root) that uses a copy of `distance` as its
// metric; the size counter starts at BK_TREE_INITIAL_SIZE.
408 BKTree(
const metric_type &distance = Metric())
409 : m_root(nullptr), m_metric(distance), m_tree_size(BK_TREE_INITIAL_SIZE) {}
// Creates a tree with a default-constructed metric and then walks `list`.
// NOTE(review): the loop body is not visible in this listing; presumably
// each string is inserted -- confirm.
411 BKTree(std::initializer_list<std::string_view> list)
412 : m_root(nullptr), m_metric(Metric()), m_tree_size(BK_TREE_INITIAL_SIZE) {
413 for (
auto &str : list) {
// Copy constructor: delegates to the metric constructor with other's metric,
// then re-inserts every word of `other` in breadth-first order starting at
// its root.
418 BKTree(
const BKTree &other) : BKTree(other.m_metric) {
// Nothing to copy from an empty tree.
419 if (other.m_root ==
nullptr) {
// Queue of pointers to the source tree's owning pointers; no ownership moves.
422 std::queue<std::unique_ptr<node_type>
const *> bq;
423 bq.push(&(other.m_root));
424 while (!bq.empty()) {
425 auto *nptr = bq.front();
427 this->insert((*nptr)->m_word);
428 for (
auto &[_, child_node] : (*nptr)->m_children) {
429 bq.push(&child_node);
434 BKTree(BKTree &&other) noexcept
435 : m_root(std::exchange(other.m_root,
nullptr)), m_tree_size(other.m_tree_size) {}
// Copy assignment: self-assignment is detected up front; otherwise the root
// and size are swapped with a local `temp`.
// NOTE(review): the construction of `temp` and the return statements are not
// visible in this listing. m_metric is const and is not reassigned by the
// visible lines.
437 BKTree &operator=(
const BKTree &other) {
438 if (
this == &other) {
442 std::swap(m_root, temp.m_root);
443 std::swap(m_tree_size, temp.m_tree_size);
// Move assignment: exchanges the root and size with `other`, so `other`
// ends up holding this tree's previous contents. m_metric is const and is
// not touched by the visible lines.
447 BKTree &operator=(BKTree &&other) noexcept {
448 std::swap(m_root, other.m_root);
449 std::swap(m_tree_size, other.m_tree_size);
// Adds `value` to the tree; returns whether it was inserted.
455 bool insert(std::string_view value);
// Removes `value` from the tree; returns whether it was removed.
456 bool erase(std::string_view value);
// Value of the size counter.
457 size_t size() const noexcept {
return m_tree_size; }
// True when the size counter is zero.
458 bool empty() const noexcept {
return m_tree_size == 0; }
// Returns the (word, distance) hits whose distance to `value` is at most
// `limit`.
459 [[nodiscard]] ResultList find(std::string_view value,
const int &limit)
const;
// Iterator positioned on the root's owning pointer.
461 Iterator begin() {
return Iterator(&m_root); }
// Default-constructed iterator used as the end marker.
462 Iterator end() {
return Iterator(); }
// Owning pointer to the root node; null for an empty tree.
465 std::unique_ptr<node_type> m_root;
// Metric instance passed to every node operation; const for the tree's life.
466 const metric_type m_metric;
// Inserts `value` below this node. The distance between `value` and this
// node's word selects the child slot: an empty slot receives a new node,
// an occupied slot forwards the insertion to the child stored there.
470 template <
typename Metric>
471 bool BKTreeNode<Metric>::_insert(std::string_view value,
472 const metric_type &distance_metric) {
// The metric returns integer_type; the result is narrowed to int here.
473 const int distance_between = distance_metric(value, m_word);
474 bool inserted =
false;
// Negative distances (after narrowing) are skipped, leaving `inserted` false.
475 if (distance_between >= 0) {
476 auto it = m_children.find(distance_between);
477 if (it == m_children.end()) {
478 m_children.emplace(std::make_pair(
479 distance_between, std::unique_ptr<node_type>(
new node_type(value))));
482 inserted = it->second->_insert(value, distance_metric);
// Removes `value` from the subtree below this node.
// NOTE(review): several lines (declaration of `erased`, queue pops, the
// branch structure around the final loop, the return) are not visible in
// this listing; comments describe only the visible statements.
488 template <
typename Metric>
489 bool BKTreeNode<Metric>::_erase(std::string_view value,
490 const metric_type &distance_metric) {
// The distance to this node's word selects the child slot to inspect.
492 const int distance_between = distance_metric(value, m_word);
493 auto it = m_children.find(distance_between);
494 if (it != m_children.end()) {
495 if (it->second->m_word == value) {
// Take ownership of the matching child and drop its map entry.
496 auto node = std::move(it->second);
497 m_children.erase(it);
// Breadth-first walk over the removed node's descendants, re-inserting each
// of their words under this node.
498 std::queue<std::unique_ptr<node_type>
const *> bq;
499 for (
auto const &[_, child_node] : node->m_children) {
500 bq.push(&child_node);
502 while (!bq.empty()) {
// This `node` is a queue entry and shadows the removed node above.
503 auto *node = bq.front();
505 for (
auto const &[_, child_node] : (*node)->m_children) {
506 bq.push(&child_node);
508 _insert((*node)->m_word, distance_metric);
// The slot holds a different word: continue the search inside that child.
512 erased = it->second->_erase(value, distance_metric);
// Tries every child in turn until one reports a successful erase.
515 for (
auto const &[_, child] : m_children) {
516 if (child->_erase(value, distance_metric)) {
// Collects into `output` every word of this subtree whose distance to
// `value` is at most `limit`, paired with that distance.
524 template <
typename Metric>
525 void BKTreeNode<Metric>::_find(ResultList &output, std::string_view value,
526 const int &limit,
const metric_type &metric)
const {
// The metric returns integer_type; the result is narrowed to int here.
527 const int distance = metric(value, m_word);
528 if (distance <= limit) {
529 output.push_back({m_word, distance});
// Only descend into children whose key differs from `distance` by at most
// `limit`.
531 for (
auto const &[dist, node] : m_children) {
532 if (std::abs(dist - distance) <= limit) {
533 node->_find(output, value, limit, metric);
// Runs _find on this node with the given query.
// NOTE(review): the declaration and return of `output` are not visible in
// this listing; presumably a local ResultList that is returned -- confirm.
538 template <
typename Metric>
539 ResultList BKTreeNode<Metric>::_find_wrapper(std::string_view value,
const int &limit,
540 const metric_type &metric)
const {
542 _find(output, value, limit, metric);
// Adds `value` to the tree: an empty tree gets a new root holding `value`;
// otherwise the insertion is forwarded to the root node using m_metric.
// NOTE(review): the updates of `inserted` and m_tree_size and the return
// statement are not visible in this listing.
546 template <
typename Metric>
547 bool BKTree<Metric>::insert(std::string_view value) {
548 bool inserted =
false;
549 if (m_root ==
nullptr) {
550 m_root = std::unique_ptr<node_type>(
new node_type(value));
553 }
else if (m_root->_insert(value, m_metric)) {
// Removes `value` from the tree. Three visible cases: empty tree, the root
// itself holds `value`, or the removal is forwarded to the root node.
// NOTE(review): many lines (result flag, queue pushes/pops, size updates,
// return) are not visible in this listing; comments describe only the
// visible statements.
560 template <
typename Metric>
561 bool BKTree<Metric>::erase(std::string_view value) {
563 if (m_root ==
nullptr) {
565 }
else if (m_root->m_word == value) {
// Root removal with other nodes present: the child with the smallest
// distance key is chosen as the replacement root.
566 if (m_tree_size > 1) {
567 auto &replacement_node = m_root->m_children.begin()->second;
568 std::queue<std::unique_ptr<node_type>
const *> bq;
// Range-for with an init-statement (a C++20 feature); `first` is declared
// here, its use is in lines not visible in this listing.
569 for (
bool first =
true;
auto const &[_, node] : m_root->m_children) {
// Breadth-first walk re-inserting queued nodes' words under the replacement.
576 while (!bq.empty()) {
577 auto node = bq.front();
579 for (
auto const &[_, child] : (*node)->m_children) {
582 replacement_node->_insert((*node)->m_word, m_metric);
// Promote the replacement: the move assignment releases it from the old
// root's map before the old root is destroyed.
584 m_root = std::move(replacement_node);
// Root was the only node: the tree becomes empty.
586 m_root.reset(
nullptr);
590 }
else if (m_root->_erase(value, m_metric)) {
// Returns the words within `limit` of `value` by forwarding the query to the
// root node together with m_metric.
// NOTE(review): the value returned for an empty tree is not visible in this
// listing.
597 template <
typename Metric>
598 ResultList BKTree<Metric>::find(std::string_view value,
const int &limit)
const {
599 if (m_root ==
nullptr) {
602 return m_root->_find_wrapper(value, limit, m_metric);
BK-tree class iterator.
Definition: bktree.hpp:356
BK-tree template class.
Definition: bktree.hpp:346
Damerau–Levenshtein metric.
Definition: bktree.hpp:259
Metric interface for string distances.
Definition: bktree.hpp:64
Edit distance metric.
Definition: bktree.hpp:221
Hamming distance metric.
Definition: bktree.hpp:180
Identity metric.
Definition: bktree.hpp:92
Longest Common Subsequence distance metric.
Definition: bktree.hpp:143
Lee distance metric.
Definition: bktree.hpp:107
Length metric.
Definition: bktree.hpp:78
BK-tree namespace.
Definition: bktree.hpp:51