#![allow(dead_code)]

//! Provides a cache with least-recently-used eviction, whose entries are
//! loaded and flushed asynchronously through an I/O backend.

use crate::vector_select::FutureVector;
use async_trait::async_trait;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::{io, mem};
use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockWriteGuard};
use tracing::{error, span, trace, Level};

/// Entry in an [`AsyncLruCache`]'s object map.
pub(crate) struct AsyncLruCacheEntry<V> {
    /// The cached object.  Always `Some(_)`; only taken out (`None`) while
    /// the entry is being evicted.
    value: Option<Arc<V>>,

    /// When this entry was last accessed, as a value of the cache's
    /// `lru_timer`.
    last_used: AtomicUsize,
}

/// Internal state of an [`AsyncLruCache`], behind an `Arc`.
struct AsyncLruCacheInner<
    Key: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
    Value: Send + Sync,
    IoBackend: AsyncLruCacheBackend<Key = Key, Value = Value>,
> {
    /// Backend performing the actual I/O (loading, flushing, evicting).
    backend: IoBackend,

    /// Cached objects, indexed by key.
    map: RwLock<HashMap<Key, AsyncLruCacheEntry<Value>>>,

    /// Other caches that must be flushed before this one.
    flush_before: Mutex<Vec<Arc<dyn FlushableCache>>>,

    /// Monotonically increasing counter, used as the LRU timestamp.
    lru_timer: AtomicUsize,

    /// Maximum number of entries this cache may hold.
    limit: usize,
}

/// Cache with least-recently-used eviction and asynchronous I/O.
///
/// Stores up to a fixed number of objects; when full, the least recently
/// used entry that is not referenced outside of the cache is flushed and
/// evicted to make room.
pub(crate) struct AsyncLruCache<
    K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
    V: Send + Sync,
    B: AsyncLruCacheBackend<Key = K, Value = V>,
>(Arc<AsyncLruCacheInner<K, V, B>>);

/// Object-safe interface for flushing a cache, so that caches with different
/// key/value types can depend on each other.
#[async_trait(?Send)]
trait FlushableCache: Send + Sync {
    /// Flush all dependencies, then all cached objects.
    async fn flush(&self) -> io::Result<()>;

    /// Check whether `other` is already registered as a direct dependency of
    /// this cache, which would make a reverse dependency circular.
    async fn check_circular(&self, other: &Arc<dyn FlushableCache>) -> bool;
}

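/// Cache backend, responsible for performing the actual I/O: loading objects
/// on a cache miss, flushing them before eviction, and dropping them on
/// invalidation.
///
/// Illustrative sketch of an implementation (not a doc test, because this
/// trait is crate-private; `TableBackend` and `TableFile` are hypothetical):
///
/// ```ignore
/// struct TableBackend {
///     file: TableFile,
/// }
///
/// impl AsyncLruCacheBackend for TableBackend {
///     type Key = usize;
///     type Value = Table;
///
///     async fn load(&self, key: usize) -> io::Result<Table> {
///         self.file.read_table(key).await
///     }
///
///     async fn flush(&self, key: usize, value: Arc<Table>) -> io::Result<()> {
///         self.file.write_table(key, &value).await
///     }
///
///     unsafe fn evict(&self, _key: usize, _value: Table) {
///         // Nothing to do; dropping the table is enough.
///     }
/// }
/// ```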
pub(crate) trait AsyncLruCacheBackend: Send + Sync {
    /// Key type used to identify cached objects.
    type Key: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync;
    /// Type of the cached objects.
    type Value: Send + Sync;

    /// Load the object with the given `key`.
    #[allow(async_fn_in_trait)]
    async fn load(&self, key: Self::Key) -> io::Result<Self::Value>;

    /// Write the given object back to the backing storage.
    ///
    /// Must not keep a clone of `value` around: the cache unwraps the `Arc`
    /// after a successful flush during eviction.
    #[allow(async_fn_in_trait)]
    async fn flush(&self, key: Self::Key, value: Arc<Self::Value>) -> io::Result<()>;

    /// Discard the given object without flushing it.
    ///
    /// # Safety
    ///
    /// The caller must ensure that dropping any unflushed state is safe.
    unsafe fn evict(&self, key: Self::Key, value: Self::Value);
}

impl<
        K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
        V: Send + Sync,
        B: AsyncLruCacheBackend<Key = K, Value = V>,
    > AsyncLruCache<K, V, B>
{
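    /// Create a new cache that keeps up to `size` objects.
    ///
    /// Illustrative usage sketch (not a doc test, because this type is
    /// crate-private; `TableBackend` is the hypothetical backend from the
    /// example on [`AsyncLruCacheBackend`]):
    ///
    /// ```ignore
    /// let cache = AsyncLruCache::new(TableBackend { file }, 128);
    ///
    /// // Loads table 7 through the backend, or returns the cached copy:
    /// let table = cache.get_or_insert(7).await?;
    ///
    /// // Write everything back to disk:
    /// cache.flush().await?;
    /// ```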
    pub fn new(backend: B, size: usize) -> Self {
        AsyncLruCache(Arc::new(AsyncLruCacheInner {
            backend,
            map: Default::default(),
            flush_before: Default::default(),
            lru_timer: AtomicUsize::new(0),
            limit: size,
        }))
    }

    /// Retrieve the object with the given `key`, loading it through the
    /// backend on a cache miss.
    pub async fn get_or_insert(&self, key: K) -> io::Result<Arc<V>> {
        self.0.get_or_insert(key).await
    }

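    /// Insert `value` into the cache under `key`, flushing any object
    /// previously stored under that key.
    ///
    /// Illustrative sketch (hypothetical `Table` value type):
    ///
    /// ```ignore
    /// let table = Arc::new(Table::new_empty());
    /// cache.insert(7, Arc::clone(&table)).await?;
    /// ```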
    pub async fn insert(&self, key: K, value: Arc<V>) -> io::Result<()> {
        self.0.insert(key, value).await
    }

    /// Flush all dependencies, then all objects in this cache.
    pub async fn flush(&self) -> io::Result<()> {
        self.0.flush().await
    }

    /// Drop all cached objects without flushing them.
    ///
    /// # Safety
    ///
    /// All unflushed modifications are lost; the caller must ensure that is
    /// acceptable.
    pub async unsafe fn invalidate(&self) -> io::Result<()> {
        unsafe { self.0.invalidate() }.await
    }
}

impl<
        K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync + 'static,
        V: Send + Sync + 'static,
        B: AsyncLruCacheBackend<Key = K, Value = V> + 'static,
    > AsyncLruCache<K, V, B>
{
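    /// Ensure that `other` is always flushed before `self`.
    ///
    /// Useful when flushing the objects in `self` requires state cached in
    /// `other` to be on disk first.  If registering the dependency would
    /// create a cycle (because `self` is already flushed before `other`),
    /// `other` is flushed first, which clears its dependency list and breaks
    /// the cycle.
    ///
    /// Illustrative sketch (hypothetical `l2_cache` and `refblock_cache`
    /// instances):
    ///
    /// ```ignore
    /// // Reference counts must hit disk before the L2 tables that point
    /// // into the clusters they cover:
    /// l2_cache.depend_on(&refblock_cache).await?;
    /// ```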
    pub async fn depend_on<
        K2: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync + 'static,
        V2: Send + Sync + 'static,
        B2: AsyncLruCacheBackend<Key = K2, Value = V2> + 'static,
    >(
        &self,
        other: &AsyncLruCache<K2, V2, B2>,
    ) -> io::Result<()> {
        let _span = span!(
            Level::TRACE,
            "AsyncLruCache::depend_on",
            self = Arc::as_ptr(&self.0) as usize,
            other = Arc::as_ptr(&other.0) as usize
        )
        .entered();

        let cloned: Arc<AsyncLruCacheInner<K2, V2, B2>> = Arc::clone(&other.0);
        let cloned: Arc<dyn FlushableCache> = cloned;

        loop {
            {
                let mut locked = self.0.flush_before.lock().await;
                if locked.iter().any(|x| Arc::ptr_eq(x, &cloned)) {
                    // Already a dependency, nothing to do.
                    break;
                }

                let self_arc: Arc<AsyncLruCacheInner<K, V, B>> = Arc::clone(&self.0);
                let self_arc: Arc<dyn FlushableCache> = self_arc;
                if !other.0.check_circular(&self_arc).await {
                    trace!("No circular dependency, entering new dependency");
                    locked.push(cloned);
                    break;
                }
            }

            trace!("Circular dependency detected, flushing other cache first");

            other.0.flush().await?;
        }

        Ok(())
    }
}

impl<
        K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
        V: Send + Sync,
        B: AsyncLruCacheBackend<Key = K, Value = V>,
    > AsyncLruCacheInner<K, V, B>
{
    /// Flush all dependencies, i.e. the caches that must be flushed before
    /// this one.
    ///
    /// Dependencies are popped before flushing and pushed back on error, so
    /// a failed flush can safely be retried.
    async fn flush_dependencies(
        flush_before: &mut MutexGuard<'_, Vec<Arc<dyn FlushableCache>>>,
    ) -> io::Result<()> {
        let _span = span!(Level::TRACE, "AsyncLruCache::flush_dependencies").entered();

        while let Some(dep) = flush_before.pop() {
            trace!("Flushing dependency {:?}", Arc::as_ptr(&dep) as *const _);
            if let Err(err) = dep.flush().await {
                flush_before.push(dep);
                return Err(err);
            }
        }
        Ok(())
    }

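    /// Ensure there is room for at least one new entry, flushing and
    /// evicting the least recently used entry that is not referenced outside
    /// of the cache if necessary.
    ///
    /// Entry ages are computed with wrapping arithmetic, so an overflowing
    /// LRU timer is harmless.  Illustrative values:
    ///
    /// ```ignore
    /// let now = 3_usize;              // timer has wrapped around
    /// let last_used = usize::MAX - 1; // entry last used before the wrap
    /// assert_eq!(now.wrapping_sub(last_used), 5); // age is still correct
    /// ```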
    async fn ensure_free_entry(
        &self,
        map: &mut RwLockWriteGuard<'_, HashMap<K, AsyncLruCacheEntry<V>>>,
    ) -> io::Result<()> {
        let _span = span!(
            Level::TRACE,
            "AsyncLruCache::ensure_free_entry",
            self = self as *const _ as usize
        )
        .entered();

        while map.len() >= self.limit {
            trace!("{} / {} used", map.len(), self.limit);

            let now = self.lru_timer.load(Ordering::Relaxed);
            let (evicted_object, key, last_used) = loop {
                // Find the oldest entry that is not referenced outside of
                // the cache.
                let oldest = map.iter().fold((0, None), |oldest, (key, entry)| {
                    if Arc::strong_count(entry.value()) > 1 {
                        return oldest;
                    }

                    let age = now.wrapping_sub(entry.last_used.load(Ordering::Relaxed));
                    if age >= oldest.0 {
                        (age, Some(*key))
                    } else {
                        oldest
                    }
                });

                let Some(oldest_key) = oldest.1 else {
                    error!("Cannot evict entry from cache; everything is in use");
                    return Err(io::Error::other(
                        "Cannot evict entry from cache; everything is in use",
                    ));
                };

                trace!("Removing entry with key {oldest_key:?}, aged {}", oldest.0);

                let mut oldest_entry = map.remove(&oldest_key).unwrap();
                match Arc::try_unwrap(oldest_entry.value.take().unwrap()) {
                    Ok(object) => {
                        break (
                            object,
                            oldest_key,
                            oldest_entry.last_used.load(Ordering::Relaxed),
                        )
                    }

                    Err(arc) => {
                        // A reference appeared after the strong count check
                        // above; restore the entry and look for another
                        // candidate.
                        trace!("Entry is still in use, retrying");

                        oldest_entry.value = Some(arc);
                        map.insert(oldest_key, oldest_entry);
                    }
                }
            };

            // Flush all dependencies, then the object itself, before
            // dropping it.
            let mut dep_guard = self.flush_before.lock().await;
            Self::flush_dependencies(&mut dep_guard).await?;

            let obj = Arc::new(evicted_object);
            trace!("Flushing {key:?}");
            if let Err(err) = self.backend.flush(key, Arc::clone(&obj)).await {
                // Flushing failed; restore the entry so the object is not
                // lost.
                map.insert(
                    key,
                    AsyncLruCacheEntry {
                        value: Some(obj),
                        last_used: last_used.into(),
                    },
                );
                return Err(err);
            }
            let _ = Arc::into_inner(obj).expect("flush() must not clone the object");
        }

        Ok(())
    }

    /// Retrieve the object with the given `key`, loading it through the
    /// backend on a cache miss.
    async fn get_or_insert(&self, key: K) -> io::Result<Arc<V>> {
        // Fast path: the object is present, a read lock suffices.
        {
            let map = self.map.read().await;
            if let Some(entry) = map.get(&key) {
                entry.last_used.store(
                    self.lru_timer.fetch_add(1, Ordering::Relaxed),
                    Ordering::Relaxed,
                );
                return Ok(Arc::clone(entry.value()));
            }
        }

        // Slow path: take the write lock, and re-check in case another task
        // loaded the object in the meantime.
        let mut map = self.map.write().await;
        if let Some(entry) = map.get(&key) {
            entry.last_used.store(
                self.lru_timer.fetch_add(1, Ordering::Relaxed),
                Ordering::Relaxed,
            );
            return Ok(Arc::clone(entry.value()));
        }

        self.ensure_free_entry(&mut map).await?;

        let object = Arc::new(self.backend.load(key).await?);

        let new_entry = AsyncLruCacheEntry {
            value: Some(Arc::clone(&object)),
            last_used: AtomicUsize::new(self.lru_timer.fetch_add(1, Ordering::Relaxed)),
        };
        map.insert(key, new_entry);

        Ok(object)
    }

    /// Insert `value` into the cache under `key`.
    ///
    /// If there already is an object under that key, it is flushed before
    /// being replaced.
    async fn insert(&self, key: K, value: Arc<V>) -> io::Result<()> {
        let mut map = self.map.write().await;
        if let Some(entry) = map.get_mut(&key) {
            entry.last_used.store(
                self.lru_timer.fetch_add(1, Ordering::Relaxed),
                Ordering::Relaxed,
            );
            let mut dep_guard = self.flush_before.lock().await;
            Self::flush_dependencies(&mut dep_guard).await?;
            self.backend.flush(key, Arc::clone(entry.value())).await?;
            entry.value = Some(value);
        } else {
            self.ensure_free_entry(&mut map).await?;

            let new_entry = AsyncLruCacheEntry {
                value: Some(value),
                last_used: AtomicUsize::new(self.lru_timer.fetch_add(1, Ordering::Relaxed)),
            };
            map.insert(key, new_entry);
        }

        Ok(())
    }

    /// Flush all dependencies, then all objects in this cache.
    async fn flush(&self) -> io::Result<()> {
        let _span = span!(
            Level::TRACE,
            "AsyncLruCache::flush",
            self = self as *const _ as usize
        )
        .entered();

        let mut futs = FutureVector::new();

        let mut dep_guard = self.flush_before.lock().await;
        Self::flush_dependencies(&mut dep_guard).await?;

        let map = self.map.read().await;
        for (key, entry) in map.iter() {
            let key = *key;
            let object = Arc::clone(entry.value());
            trace!("Flushing {key:?}");
            futs.push(Box::pin(self.backend.flush(key, object)));
        }

        futs.discarding_join().await
    }

    /// Drop all cached objects without flushing them.
    ///
    /// Fails if any entry is still referenced outside of the cache; such
    /// entries are retained.
    ///
    /// # Safety
    ///
    /// All unflushed modifications are lost; the caller must ensure that is
    /// acceptable.
    async unsafe fn invalidate(&self) -> io::Result<()> {
        let _span = span!(
            Level::TRACE,
            "AsyncLruCache::invalidate",
            self = self as *const _ as usize
        )
        .entered();

        let mut in_use = Vec::new();

        let mut map = self.map.write().await;
        let old_map = mem::take(&mut *map);
        for (key, mut entry) in old_map {
            let object = entry.value.take().unwrap();
            trace!("Evicting {key:?}");
            match Arc::try_unwrap(object) {
                Ok(object) => {
                    // Safety: guaranteed by the caller.
                    unsafe { self.backend.evict(key, object) };
                }

                Err(arc) => {
                    trace!("Entry is still in use, retaining it");
                    entry.value = Some(arc);
                    map.insert(key, entry);
                    in_use.push(key);
                }
            }
        }

        if in_use.is_empty() {
            self.flush_before.lock().await.clear();
            Ok(())
        } else {
            Err(io::Error::other(format!(
                "Cannot invalidate cache, entries still in use: {}",
                in_use
                    .iter()
                    .map(|key| format!("{key:?}"))
                    .collect::<Vec<String>>()
                    .join(", "),
            )))
        }
    }
}

impl<V> AsyncLruCacheEntry<V> {
    /// Return the cached object.
    fn value(&self) -> &Arc<V> {
        self.value.as_ref().unwrap()
    }
}

#[async_trait(?Send)]
impl<
        K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
        V: Send + Sync,
        B: AsyncLruCacheBackend<Key = K, Value = V>,
    > FlushableCache for AsyncLruCacheInner<K, V, B>
{
    async fn flush(&self) -> io::Result<()> {
        AsyncLruCacheInner::<K, V, B>::flush(self).await
    }

    async fn check_circular(&self, other: &Arc<dyn FlushableCache>) -> bool {
        let deps = self.flush_before.lock().await;
        for dep in deps.iter() {
            if Arc::ptr_eq(dep, other) {
                return true;
            }
        }
        false
    }
}
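
// What follows is an illustrative smoke-test sketch, not part of the original
// module: it assumes tokio's "macros" and "rt" features are enabled for
// tests, and `DummyBackend` is a hypothetical in-memory backend.
#[cfg(test)]
mod example_tests {
    use super::*;

    /// Hypothetical backend: "loads" the key itself as the value, and
    /// ignores flushes and evictions.
    struct DummyBackend;

    impl AsyncLruCacheBackend for DummyBackend {
        type Key = usize;
        type Value = usize;

        async fn load(&self, key: usize) -> io::Result<usize> {
            Ok(key)
        }

        async fn flush(&self, _key: usize, _value: Arc<usize>) -> io::Result<()> {
            Ok(())
        }

        unsafe fn evict(&self, _key: usize, _value: usize) {}
    }

    #[tokio::test]
    async fn smoke() -> io::Result<()> {
        let cache = AsyncLruCache::new(DummyBackend, 2);

        // The first access loads the object through the backend...
        let a = cache.get_or_insert(7).await?;
        assert_eq!(*a, 7);
        // ...and further accesses return the same cached object.
        let b = cache.get_or_insert(7).await?;
        assert!(Arc::ptr_eq(&a, &b));

        // Once no external references remain, a full cache flushes and
        // evicts the least recently used entry (key 7 here) to make room.
        drop((a, b));
        cache.get_or_insert(1).await?;
        cache.get_or_insert(2).await?;

        cache.flush().await
    }
}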