1#![allow(dead_code)]
11
12use crate::vector_select::FutureVector;
13use std::collections::HashMap;
14use std::fmt::Debug;
15use std::hash::Hash;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::Arc;
18use std::{io, mem};
19use tokio::sync::{RwLock, RwLockWriteGuard};
20use tracing::{error, instrument, trace};
21
/// A single slot in the cache map.
pub(crate) struct AsyncLruCacheEntry<V> {
    /// LRU clock value recorded the last time this entry was used; the entry's age is the
    /// (wrapping) distance between the current clock and this stamp.
    last_used: AtomicUsize,

    /// The cached object.
    ///
    /// Wrapped in an `Option` so the object can be moved out of an entry that has already been
    /// taken out of the map; entries stored in the map always hold `Some`.
    value: Option<Arc<V>>,
}
32
33struct AsyncLruCacheInner<
35 Key: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
36 Value: Send + Sync,
37 IoBackend: AsyncLruCacheBackend<Key = Key, Value = Value>,
38> {
39 backend: IoBackend,
41
42 map: RwLock<HashMap<Key, AsyncLruCacheEntry<Value>>>,
44
45 lru_timer: AtomicUsize,
47
48 limit: usize,
50}
51
52pub(crate) struct AsyncLruCache<
59 K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
60 V: Send + Sync,
61 B: AsyncLruCacheBackend<Key = K, Value = V>,
62>(Arc<AsyncLruCacheInner<K, V, B>>);
63
/// Storage backend behind an `AsyncLruCache`: provides objects on a cache miss, writes them
/// back, and discards them when the cache is invalidated.
// The `async_fn_in_trait` lint warns that callers cannot require the futures returned by
// `async fn` trait methods to be `Send`; that limitation is accepted for this whole trait.
#[allow(async_fn_in_trait)]
pub(crate) trait AsyncLruCacheBackend: Send + Sync {
    /// Key identifying a cached object.
    type Key: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync;

    /// Type of the cached objects.
    type Value: Send + Sync;

    /// Loads the object for `key`; the cache calls this when `key` is not cached yet.
    async fn load(&self, key: Self::Key) -> io::Result<Self::Value>;

    /// Writes `value` back for `key`.
    ///
    /// The cache calls this before evicting an entry to make room, before replacing the value
    /// of an existing entry, and for every cached entry when the whole cache is flushed. The
    /// cache does not track whether a value was modified, so this may be called repeatedly for
    /// the same, unchanged value.
    async fn flush(&self, key: Self::Key, value: &Self::Value) -> io::Result<()>;

    /// Discards `value` for `key` without writing it back; used when the cache is invalidated.
    ///
    /// # Safety
    ///
    /// NOTE(review): the original contract is not visible here. The cache only passes values
    /// that are no longer referenced outside of it, and does not flush them first; presumably
    /// the caller must ensure that dropping any unflushed state is acceptable — confirm.
    unsafe fn evict(&self, key: Self::Key, value: Self::Value);
}
93
94impl<
95 K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
96 V: Send + Sync,
97 B: AsyncLruCacheBackend<Key = K, Value = V>,
98 > AsyncLruCache<K, V, B>
99{
100 pub fn new(backend: B, size: usize) -> Self {
104 AsyncLruCache(Arc::new(AsyncLruCacheInner {
105 backend,
106 map: Default::default(),
107 lru_timer: AtomicUsize::new(0),
108 limit: size,
109 }))
110 }
111
112 pub async fn get_or_insert(&self, key: K, may_flush: bool) -> io::Result<Option<Arc<V>>> {
122 self.0.get_or_insert(key, may_flush).await
123 }
124
125 pub async fn insert(&self, key: K, value: Arc<V>, may_flush: bool) -> io::Result<bool> {
136 self.0.insert(key, value, may_flush).await
137 }
138
139 pub async fn flush(&self) -> io::Result<()> {
143 self.0.flush().await
144 }
145
146 pub async unsafe fn invalidate(&self) -> io::Result<()> {
154 unsafe { self.0.invalidate() }.await
155 }
156}
157
158impl<
159 K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
160 V: Send + Sync,
161 B: AsyncLruCacheBackend<Key = K, Value = V>,
162 > AsyncLruCacheInner<K, V, B>
163{
164 #[instrument(
178 level = "trace",
179 name = "AsyncLruCache::ensure_free_entry",
180 skip_all,
181 fields(self = &self as *const _ as usize),
182 )]
183 async fn ensure_free_entry(
184 &self,
185 map: &mut RwLockWriteGuard<'_, HashMap<K, AsyncLruCacheEntry<V>>>,
186 may_evict: bool,
187 ) -> io::Result<bool> {
188 if map.len() < self.limit {
189 return Ok(true);
190 } else if !may_evict {
191 return Ok(false);
192 }
193
194 while map.len() >= self.limit {
195 trace!("{} / {} used", map.len(), self.limit);
196
197 let now = self.lru_timer.load(Ordering::Relaxed);
198 let oldest = map
199 .iter()
200 .filter(|(_key, entry)| Arc::strong_count(entry.value()) == 1)
201 .fold((0, None), |oldest, (key, entry)| {
202 assert_eq!(Arc::weak_count(entry.value()), 0);
206
207 let age = now.wrapping_sub(entry.last_used.load(Ordering::Relaxed));
208 if age >= oldest.0 {
209 (age, Some(*key))
210 } else {
211 oldest
212 }
213 });
214
215 let Some(oldest_key) = oldest.1 else {
216 error!("Cannot evict entry from cache; everything is in use");
217 return Err(io::Error::other(
218 "Cannot evict entry from cache; everything is in use",
219 ));
220 };
221
222 trace!("Removing entry with key {oldest_key:?}, aged {}", oldest.0);
223
224 let oldest_entry = map.remove(&oldest_key).unwrap();
225
226 let evicted_object = Arc::try_unwrap(oldest_entry.value.unwrap())
231 .unwrap_or_else(|_| panic!("entry has gained external references"));
232
233 trace!("Flushing {oldest_key:?}");
234 if let Err(err) = self.backend.flush(oldest_key, &evicted_object).await {
235 map.insert(
236 oldest_key,
237 AsyncLruCacheEntry {
238 value: Some(Arc::new(evicted_object)),
239 last_used: oldest_entry.last_used.load(Ordering::Relaxed).into(),
240 },
241 );
242 return Err(err);
243 }
244 }
245
246 Ok(true)
247 }
248
249 async fn get_or_insert(&self, key: K, may_flush: bool) -> io::Result<Option<Arc<V>>> {
261 {
262 let map = self.map.read().await;
263 if let Some(entry) = map.get(&key) {
264 entry.last_used.store(
265 self.lru_timer.fetch_add(1, Ordering::Relaxed),
266 Ordering::Relaxed,
267 );
268 return Ok(Some(Arc::clone(entry.value())));
269 }
270 }
271
272 let mut map = self.map.write().await;
273 if let Some(entry) = map.get(&key) {
274 entry.last_used.store(
275 self.lru_timer.fetch_add(1, Ordering::Relaxed),
276 Ordering::Relaxed,
277 );
278 return Ok(Some(Arc::clone(entry.value())));
279 }
280
281 if !self.ensure_free_entry(&mut map, may_flush).await? {
282 return Ok(None);
283 }
284
285 let object = Arc::new(self.backend.load(key).await?);
286
287 let new_entry = AsyncLruCacheEntry {
288 value: Some(Arc::clone(&object)),
289 last_used: AtomicUsize::new(self.lru_timer.fetch_add(1, Ordering::Relaxed)),
290 };
291 map.insert(key, new_entry);
292
293 Ok(Some(object))
294 }
295
296 async fn insert(&self, key: K, value: Arc<V>, may_flush: bool) -> io::Result<bool> {
307 let mut map = self.map.write().await;
308 if let Some(entry) = map.get_mut(&key) {
309 if !may_flush {
310 return Ok(false);
311 }
312
313 entry.last_used.store(
314 self.lru_timer.fetch_add(1, Ordering::Relaxed),
315 Ordering::Relaxed,
316 );
317 self.backend.flush(key, entry.value()).await?;
318 entry.value = Some(value);
319 } else {
320 if !self.ensure_free_entry(&mut map, may_flush).await? {
321 return Ok(false);
322 }
323
324 let new_entry = AsyncLruCacheEntry {
325 value: Some(value),
326 last_used: AtomicUsize::new(self.lru_timer.fetch_add(1, Ordering::Relaxed)),
327 };
328 map.insert(key, new_entry);
329 }
330
331 Ok(true)
332 }
333
334 #[instrument(
338 level = "trace",
339 name = "AsyncLruCache::flush",
340 skip_all,
341 fields(self = &self as *const _ as usize)
342 )]
343 async fn flush(&self) -> io::Result<()> {
344 let mut futs = FutureVector::new();
345
346 let map = self.map.read().await;
347 for (key, entry) in map.iter() {
348 let key = *key;
349 let object = Arc::clone(entry.value());
350 trace!("Flushing {key:?}");
351 futs.push(Box::pin(
352 async move { self.backend.flush(key, &object).await },
353 ));
354 }
355
356 let mut first_err = None;
357 while let Err(e) = futs.discarding_join().await {
358 first_err.get_or_insert(e);
359 }
360 if let Some(e) = first_err {
361 Err(e)
362 } else {
363 Ok(())
364 }
365 }
366
367 #[instrument(
375 level = "trace",
376 name = "AsyncLruCache::invalidate",
377 skip_all,
378 fields(self = &self as *const _ as usize)
379 )]
380 async unsafe fn invalidate(&self) -> io::Result<()> {
381 let mut in_use = Vec::new();
382
383 let mut map = self.map.write().await;
384 let old_map = mem::take(&mut *map);
387 for (key, mut entry) in old_map {
388 let object = entry.value.take().unwrap();
389 trace!("Evicting {key:?}");
390 match Arc::try_unwrap(object) {
391 Ok(object) => {
392 unsafe { self.backend.evict(key, object) };
394 }
395
396 Err(arc) => {
397 trace!("Entry is still in use, retaining it");
398 entry.value = Some(arc);
399 map.insert(key, entry);
400 in_use.push(key);
401 }
402 }
403 }
404
405 if in_use.is_empty() {
406 Ok(())
407 } else {
408 Err(io::Error::other(format!(
409 "Cannot invalidate cache, entries still in use: {}",
410 in_use
411 .iter()
412 .map(|key| format!("{key:?}"))
413 .collect::<Vec<String>>()
414 .join(", "),
415 )))
416 }
417 }
418}
419
420impl<V> AsyncLruCacheEntry<V> {
421 fn value(&self) -> &Arc<V> {
423 self.value.as_ref().unwrap()
424 }
425}
426
/// Unit tests for the LRU cache: eviction order, in-use protection, flush error handling,
/// invalidation, and replacement on insert.
#[cfg(test)]
mod tests {
    use super::*;
    // `use super::*` already brings `AtomicUsize` in (the tests also rely on it for `Arc`,
    // `io` and `Ordering`); this explicit import is redundant but harmless.
    use std::sync::atomic::AtomicUsize;

    /// Backend whose objects are their own keys and whose flushes always succeed.
    struct DummyBackend;

    impl AsyncLruCacheBackend for DummyBackend {
        type Key = usize;
        type Value = usize;

        async fn load(&self, key: usize) -> io::Result<usize> {
            Ok(key)
        }

        async fn flush(&self, _key: usize, _value: &usize) -> io::Result<()> {
            Ok(())
        }

        unsafe fn evict(&self, _key: usize, _value: usize) {}
    }

    /// Like `DummyBackend`, but records every `(key, value)` pair passed to `flush`, in call
    /// order.
    #[derive(Default)]
    struct RecordingBackend {
        flushed: std::sync::Mutex<Vec<(usize, usize)>>,
    }

    impl AsyncLruCacheBackend for RecordingBackend {
        type Key = usize;
        type Value = usize;

        async fn load(&self, key: usize) -> io::Result<usize> {
            Ok(key)
        }

        async fn flush(&self, key: usize, value: &usize) -> io::Result<()> {
            self.flushed.lock().unwrap().push((key, *value));
            Ok(())
        }

        unsafe fn evict(&self, _key: usize, _value: usize) {}
    }

    /// Forwards every backend call to the shared `B`, so a test can keep one `Arc` to inspect
    /// the backend while the cache owns another.
    impl<B: AsyncLruCacheBackend> AsyncLruCacheBackend for Arc<B> {
        type Key = <B as AsyncLruCacheBackend>::Key;
        type Value = <B as AsyncLruCacheBackend>::Value;

        async fn load(&self, key: Self::Key) -> io::Result<Self::Value> {
            (**self).load(key).await
        }

        async fn flush(&self, key: Self::Key, value: &Self::Value) -> io::Result<()> {
            (**self).flush(key, value).await
        }

        unsafe fn evict(&self, key: Self::Key, value: Self::Value) {
            // SAFETY: forwarded unchanged; the caller upholds `evict`'s contract for `B`.
            unsafe { (**self).evict(key, value) }
        }
    }

    /// `flush()` must attempt every entry even though half of the flushes fail, and must
    /// report one of the failures.
    #[tokio::test]
    async fn test_flush_continues_past_errors() {
        /// Counts flush calls and fails those for odd keys.
        #[derive(Default)]
        struct FailOddBackend {
            flush_count: AtomicUsize,
        }

        impl AsyncLruCacheBackend for FailOddBackend {
            type Key = usize;
            type Value = usize;

            async fn load(&self, key: usize) -> io::Result<usize> {
                Ok(key)
            }

            async fn flush(&self, key: usize, _value: &usize) -> io::Result<()> {
                self.flush_count.fetch_add(1, Ordering::Relaxed);
                if key % 2 == 1 {
                    Err(io::Error::other("odd key"))
                } else {
                    Ok(())
                }
            }

            unsafe fn evict(&self, _key: usize, _value: usize) {}
        }

        const ENTRIES: usize = 42;

        let backend = Arc::new(FailOddBackend::default());
        let cache = AsyncLruCache::new(Arc::clone(&backend), ENTRIES);

        // Fill the cache exactly to its limit, so no eviction (and no flush) happens yet
        for i in 0..ENTRIES {
            cache.get_or_insert(i, false).await.unwrap().unwrap();
        }

        let err = cache.flush().await.unwrap_err();
        assert!(err.to_string().contains("odd key"));

        // Every entry was flushed, not just those before the first failure
        assert_eq!(backend.flush_count.load(Ordering::Relaxed), ENTRIES);
    }

    /// The least recently used entry is the one evicted (and flushed).
    #[tokio::test]
    async fn test_lru_eviction_order() {
        const ENTRIES: usize = 3;

        let backend = Arc::new(RecordingBackend::default());
        let cache = AsyncLruCache::new(Arc::clone(&backend), ENTRIES);

        for i in 0..ENTRIES {
            cache.get_or_insert(i, false).await.unwrap().unwrap();
        }

        // Touch 0 again, so 1 becomes the least recently used entry
        cache.get_or_insert(0, false).await.unwrap().unwrap();

        // Cache is full: without permission to flush, nothing can be loaded
        assert_eq!(cache.get_or_insert(ENTRIES, false).await.unwrap(), None);
        // With permission, the LRU entry (1) is flushed and evicted
        cache.get_or_insert(ENTRIES, true).await.unwrap().unwrap();

        assert_eq!(*backend.flushed.lock().unwrap(), [(1, 1)]);
    }

    /// An entry referenced outside the cache is skipped by eviction even if it is the oldest.
    #[tokio::test]
    async fn test_in_use_entries_not_evicted() {
        let backend = Arc::new(RecordingBackend::default());
        let cache = AsyncLruCache::new(Arc::clone(&backend), 2);

        // Entry 0 is the oldest, but `held` keeps it in use
        let held = cache.get_or_insert(0, false).await.unwrap().unwrap();
        cache.get_or_insert(1, false).await.unwrap().unwrap();

        assert_eq!(cache.get_or_insert(2, false).await.unwrap(), None);
        // Only entry 1 is evictable
        cache.get_or_insert(2, true).await.unwrap().unwrap();

        assert_eq!(*backend.flushed.lock().unwrap(), [(1, 1)]);
        assert_eq!(*held, 0);
    }

    /// When every entry is in use, a load that needs room fails instead of evicting.
    #[tokio::test]
    async fn test_cache_full_all_in_use() {
        const ENTRIES: usize = 23;

        let cache = AsyncLruCache::new(DummyBackend, ENTRIES);

        // Keep a reference to every entry
        let mut held = vec![];
        for i in 0..ENTRIES {
            held.push(cache.get_or_insert(i, false).await.unwrap().unwrap());
        }

        assert_eq!(cache.get_or_insert(ENTRIES, false).await.unwrap(), None);
        let err = cache.get_or_insert(ENTRIES, true).await.unwrap_err();
        assert!(err.to_string().contains("everything is in use"));
    }

    /// Invalidation drops unused entries, keeps in-use ones (same `Arc`), and reports them.
    #[tokio::test]
    async fn test_invalidate_retains_in_use() {
        let cache = AsyncLruCache::new(DummyBackend, 16);

        let held = cache.get_or_insert(0, false).await.unwrap().unwrap();
        cache.get_or_insert(1, false).await.unwrap().unwrap();
        cache.get_or_insert(2, false).await.unwrap().unwrap();

        // SAFETY: `DummyBackend::evict` does nothing, and no test state depends on the
        // evicted values being flushed.
        let err = unsafe { cache.invalidate() }.await.unwrap_err();
        assert!(err.to_string().contains("still in use"));

        // The retained entry is still served from the cache, not reloaded
        let from_cache = cache.get_or_insert(0, false).await.unwrap().unwrap();
        assert!(Arc::ptr_eq(&from_cache, &held));
        let from_cache = cache.get_or_insert(0, true).await.unwrap().unwrap();
        assert!(Arc::ptr_eq(&from_cache, &held));

        // Entries 1 and 2 are gone
        assert_eq!(cache.0.map.read().await.len(), 1);
    }

    /// A failed flush during eviction leaves the would-be victim cached and usable.
    #[tokio::test]
    async fn test_eviction_flush_failure_reinserts_entry() {
        /// Backend whose flushes always fail.
        struct FailFlushBackend;

        impl AsyncLruCacheBackend for FailFlushBackend {
            type Key = usize;
            type Value = usize;

            async fn load(&self, key: usize) -> io::Result<usize> {
                Ok(key)
            }

            async fn flush(&self, _key: usize, _value: &usize) -> io::Result<()> {
                Err(io::Error::other("flush failed"))
            }

            unsafe fn evict(&self, _key: usize, _value: usize) {}
        }

        const ENTRIES: usize = 2;

        let cache = AsyncLruCache::new(FailFlushBackend, ENTRIES);

        for i in 0..ENTRIES {
            cache.get_or_insert(i, false).await.unwrap().unwrap();
        }

        assert_eq!(cache.get_or_insert(ENTRIES, false).await.unwrap(), None);
        // Eviction is attempted, but the victim's flush fails
        let err = cache.get_or_insert(ENTRIES, true).await.unwrap_err();
        assert!(err.to_string().contains("flush failed"));

        // No entry was lost, and every original entry is still served with its value
        assert_eq!(cache.0.map.read().await.len(), ENTRIES);
        for i in 0..ENTRIES {
            let entry = cache.get_or_insert(i, false).await.unwrap().unwrap();
            assert_eq!(*entry, i);
        }

        // The cache remains in the same state: still full, still unable to evict
        assert_eq!(cache.get_or_insert(ENTRIES, false).await.unwrap(), None);
        let err = cache.get_or_insert(ENTRIES, true).await.unwrap_err();
        assert!(err.to_string().contains("flush failed"));
    }

    /// Inserting over an existing key requires `may_flush`, flushes the old value, and then
    /// serves the new one.
    #[tokio::test]
    async fn test_insert_flushes_existing() {
        let backend = Arc::new(RecordingBackend::default());
        let cache = AsyncLruCache::new(Arc::clone(&backend), 16);

        cache.get_or_insert(5, false).await.unwrap().unwrap();
        // Replacing requires a flush of the current value, which is not allowed here
        assert!(!cache.insert(5, Arc::new(55), false).await.unwrap());
        assert!(cache.insert(5, Arc::new(55), true).await.unwrap());

        // The old value (5) was flushed; the new one (55) is now cached in the same slot
        assert_eq!(*backend.flushed.lock().unwrap(), [(5, 5)]);
        assert_eq!(*cache.get_or_insert(5, false).await.unwrap().unwrap(), 55);
        assert_eq!(*cache.get_or_insert(5, true).await.unwrap().unwrap(), 55);
        assert_eq!(cache.0.map.read().await.len(), 1);
    }
}