1#![allow(dead_code)]
11
12use crate::vector_select::FutureVector;
13use async_trait::async_trait;
14use std::collections::HashMap;
15use std::fmt::Debug;
16use std::hash::Hash;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::sync::Arc;
19use std::{io, mem};
20use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockWriteGuard};
21use tracing::{error, instrument, trace};
22
/// A single cache entry: the cached object plus its LRU bookkeeping.
pub(crate) struct AsyncLruCacheEntry<V> {
    /// The cached object.  Always `Some(_)` while the entry is stored in the
    /// cache map; only `take()`n transiently during eviction/invalidation.
    value: Option<Arc<V>>,

    /// LRU timestamp of the last access, in ticks of the cache's `lru_timer`
    /// counter.  Older (smaller, modulo wrap-around) means evict sooner.
    last_used: AtomicUsize,
}
33
/// Shared interior of an `AsyncLruCache`: the object map, LRU clock, backend,
/// and the list of caches that must be flushed before this one.
struct AsyncLruCacheInner<
    Key: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
    Value: Send + Sync,
    IoBackend: AsyncLruCacheBackend<Key = Key, Value = Value>,
> {
    /// I/O backend used to load, flush, and evict cached objects.
    backend: IoBackend,

    /// The cached objects, keyed by `Key`.
    map: RwLock<HashMap<Key, AsyncLruCacheEntry<Value>>>,

    /// Caches that must be flushed before this cache's own entries are
    /// written back (maintained by `depend_on()`, drained by
    /// `flush_dependencies()`).
    flush_before: Mutex<Vec<Arc<dyn FlushableCache>>>,

    /// Monotonically increasing access counter serving as the LRU clock.
    lru_timer: AtomicUsize,

    /// Maximum number of entries the map may hold.
    limit: usize,
}
55
/// Asynchronous LRU cache: a thin public handle around an `Arc`-shared
/// [`AsyncLruCacheInner`], so caches can also be referenced (type-erased) from
/// other caches' dependency lists.
pub(crate) struct AsyncLruCache<
    K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
    V: Send + Sync,
    B: AsyncLruCacheBackend<Key = K, Value = V>,
>(Arc<AsyncLruCacheInner<K, V, B>>);
67
/// Type-erased view of a flushable cache, so caches with different key/value
/// types can be stored together in a `flush_before` dependency list.
#[async_trait(?Send)]
trait FlushableCache: Send + Sync {
    /// Flush all cached data (dependencies first).
    async fn flush(&self) -> io::Result<()>;

    /// Return whether `other` is a *direct* dependency of this cache, i.e.
    /// whether adding the reverse edge would create a dependency cycle.
    async fn check_circular(&self, other: &Arc<dyn FlushableCache>) -> bool;
}
79
/// I/O backend for an [`AsyncLruCache`]: loads objects into the cache, writes
/// them back, and discards them on invalidation.
pub(crate) trait AsyncLruCacheBackend: Send + Sync {
    /// Cache key type.
    type Key: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync;
    /// Cached object type.
    type Value: Send + Sync;

    /// Load the object associated with `key`.
    #[allow(async_fn_in_trait)] async fn load(&self, key: Self::Key) -> io::Result<Self::Value>;

    /// Write the object back under `key`.  Implementations must not retain a
    /// clone of `value`: the cache unwraps the `Arc` after a successful flush.
    #[allow(async_fn_in_trait)] async fn flush(&self, key: Self::Key, value: Arc<Self::Value>) -> io::Result<()>;

    /// Discard `value` without flushing it first.
    ///
    /// # Safety
    /// NOTE(review): marked `unsafe` because the object is dropped without
    /// being written back — confirm the exact contract implementors must
    /// uphold (e.g. when discarding dirty data is permissible).
    unsafe fn evict(&self, key: Self::Key, value: Self::Value);
}
109
110impl<
111 K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
112 V: Send + Sync,
113 B: AsyncLruCacheBackend<Key = K, Value = V>,
114 > AsyncLruCache<K, V, B>
115{
116 pub fn new(backend: B, size: usize) -> Self {
120 AsyncLruCache(Arc::new(AsyncLruCacheInner {
121 backend,
122 map: Default::default(),
123 flush_before: Default::default(),
124 lru_timer: AtomicUsize::new(0),
125 limit: size,
126 }))
127 }
128
129 pub async fn get_or_insert(&self, key: K) -> io::Result<Arc<V>> {
134 self.0.get_or_insert(key).await
135 }
136
137 pub async fn insert(&self, key: K, value: Arc<V>) -> io::Result<()> {
141 self.0.insert(key, value).await
142 }
143
144 pub async fn flush(&self) -> io::Result<()> {
148 self.0.flush().await
149 }
150
151 pub async unsafe fn invalidate(&self) -> io::Result<()> {
159 unsafe { self.0.invalidate() }.await
160 }
161}
162
impl<
        K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync + 'static,
        V: Send + Sync + 'static,
        B: AsyncLruCacheBackend<Key = K, Value = V> + 'static,
    > AsyncLruCache<K, V, B>
{
    /// Record that `other` must be flushed before this cache is flushed.
    ///
    /// If `other` already (directly) depends on `self`, adding the edge would
    /// create a cycle; in that case `other` is flushed first — which drains
    /// its own dependency list — and the check is retried.
    #[instrument(
        level = "trace",
        name = "AsyncLruCache::depend_on",
        skip_all,
        fields(
            self = Arc::as_ptr(&self.0) as usize,
            other = Arc::as_ptr(&other.0) as usize,
        )
    )]
    pub async fn depend_on<
        K2: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync + 'static,
        V2: Send + Sync + 'static,
        B2: AsyncLruCacheBackend<Key = K2, Value = V2> + 'static,
    >(
        &self,
        other: &AsyncLruCache<K2, V2, B2>,
    ) -> io::Result<()> {
        // Type-erase `other` so it can be stored in `flush_before`.
        let cloned: Arc<AsyncLruCacheInner<K2, V2, B2>> = Arc::clone(&other.0);
        let cloned: Arc<dyn FlushableCache> = cloned;

        loop {
            {
                let mut locked = self.0.flush_before.lock().await;
                // Already a dependency; nothing to do.
                if locked.iter().any(|x| Arc::ptr_eq(x, &cloned)) {
                    break;
                }

                let self_arc: Arc<AsyncLruCacheInner<K, V, B>> = Arc::clone(&self.0);
                let self_arc: Arc<dyn FlushableCache> = self_arc;
                if !other.0.check_circular(&self_arc).await {
                    trace!("No circular dependency, entering new dependency");
                    locked.push(cloned);
                    break;
                }
            }

            // NOTE(review): the `flush_before` lock is deliberately released
            // (scope above) before flushing `other` — presumably to avoid
            // deadlocking against `other`'s own flush; confirm.
            trace!("Circular dependency detected, flushing other cache first");

            other.0.flush().await?;
        }

        Ok(())
    }
}
217
218impl<
219 K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
220 V: Send + Sync,
221 B: AsyncLruCacheBackend<Key = K, Value = V>,
222 > AsyncLruCacheInner<K, V, B>
223{
224 #[instrument(level = "trace", name = "AsyncLruCache::flush_dependencies", skip_all)]
232 async fn flush_dependencies(
233 flush_before: &mut MutexGuard<'_, Vec<Arc<dyn FlushableCache>>>,
234 ) -> io::Result<()> {
235 while let Some(dep) = flush_before.pop() {
236 trace!("Flushing dependency {:?}", Arc::as_ptr(&dep) as *const _);
237 if let Err(err) = dep.flush().await {
238 flush_before.push(dep);
239 return Err(err);
240 }
241 }
242 Ok(())
243 }
244
245 #[instrument(
249 level = "trace",
250 name = "AsyncLruCache::ensure_free_entry",
251 skip_all,
252 fields(self = &self as *const _ as usize),
253 )]
254 async fn ensure_free_entry(
255 &self,
256 map: &mut RwLockWriteGuard<'_, HashMap<K, AsyncLruCacheEntry<V>>>,
257 ) -> io::Result<()> {
258 while map.len() >= self.limit {
259 trace!("{} / {} used", map.len(), self.limit);
260
261 let now = self.lru_timer.load(Ordering::Relaxed);
262 let (evicted_object, key, last_used) = loop {
263 let oldest = map.iter().fold((0, None), |oldest, (key, entry)| {
264 if Arc::strong_count(entry.value()) > 1 {
266 return oldest;
267 }
268
269 let age = now.wrapping_sub(entry.last_used.load(Ordering::Relaxed));
270 if age >= oldest.0 {
271 (age, Some(*key))
272 } else {
273 oldest
274 }
275 });
276
277 let Some(oldest_key) = oldest.1 else {
278 error!("Cannot evict entry from cache; everything is in use");
279 return Err(io::Error::other(
280 "Cannot evict entry from cache; everything is in use",
281 ));
282 };
283
284 trace!("Removing entry with key {oldest_key:?}, aged {}", oldest.0);
285
286 let mut oldest_entry = map.remove(&oldest_key).unwrap();
287 match Arc::try_unwrap(oldest_entry.value.take().unwrap()) {
288 Ok(object) => {
289 break (
290 object,
291 oldest_key,
292 oldest_entry.last_used.load(Ordering::Relaxed),
293 )
294 }
295 Err(arc) => {
296 trace!("Entry is still in use, retrying");
297
298 oldest_entry.value = Some(arc);
302 }
303 }
304 };
305
306 let mut dep_guard = self.flush_before.lock().await;
307 Self::flush_dependencies(&mut dep_guard).await?;
308 let obj = Arc::new(evicted_object);
309 trace!("Flushing {key:?}");
310 if let Err(err) = self.backend.flush(key, Arc::clone(&obj)).await {
311 map.insert(
312 key,
313 AsyncLruCacheEntry {
314 value: Some(obj),
315 last_used: last_used.into(),
316 },
317 );
318 return Err(err);
319 }
320 let _ = Arc::into_inner(obj).expect("flush() must not clone the object");
321 }
322
323 Ok(())
324 }
325
326 async fn get_or_insert(&self, key: K) -> io::Result<Arc<V>> {
331 {
332 let map = self.map.read().await;
333 if let Some(entry) = map.get(&key) {
334 entry.last_used.store(
335 self.lru_timer.fetch_add(1, Ordering::Relaxed),
336 Ordering::Relaxed,
337 );
338 return Ok(Arc::clone(entry.value()));
339 }
340 }
341
342 let mut map = self.map.write().await;
343 if let Some(entry) = map.get(&key) {
344 entry.last_used.store(
345 self.lru_timer.fetch_add(1, Ordering::Relaxed),
346 Ordering::Relaxed,
347 );
348 return Ok(Arc::clone(entry.value()));
349 }
350
351 self.ensure_free_entry(&mut map).await?;
352
353 let object = Arc::new(self.backend.load(key).await?);
354
355 let new_entry = AsyncLruCacheEntry {
356 value: Some(Arc::clone(&object)),
357 last_used: AtomicUsize::new(self.lru_timer.fetch_add(1, Ordering::Relaxed)),
358 };
359 map.insert(key, new_entry);
360
361 Ok(object)
362 }
363
364 async fn insert(&self, key: K, value: Arc<V>) -> io::Result<()> {
368 let mut map = self.map.write().await;
369 if let Some(entry) = map.get_mut(&key) {
370 entry.last_used.store(
371 self.lru_timer.fetch_add(1, Ordering::Relaxed),
372 Ordering::Relaxed,
373 );
374 let mut dep_guard = self.flush_before.lock().await;
375 Self::flush_dependencies(&mut dep_guard).await?;
376 self.backend.flush(key, Arc::clone(entry.value())).await?;
377 entry.value = Some(value);
378 } else {
379 self.ensure_free_entry(&mut map).await?;
380
381 let new_entry = AsyncLruCacheEntry {
382 value: Some(value),
383 last_used: AtomicUsize::new(self.lru_timer.fetch_add(1, Ordering::Relaxed)),
384 };
385 map.insert(key, new_entry);
386 }
387
388 Ok(())
389 }
390
391 #[instrument(
395 level = "trace",
396 name = "AsyncLruCache::flush",
397 skip_all,
398 fields(self = &self as *const _ as usize)
399 )]
400 async fn flush(&self) -> io::Result<()> {
401 let mut futs = FutureVector::new();
402
403 let mut dep_guard = self.flush_before.lock().await;
404 Self::flush_dependencies(&mut dep_guard).await?;
405
406 let map = self.map.read().await;
407 for (key, entry) in map.iter() {
408 let key = *key;
409 let object = Arc::clone(entry.value());
410 trace!("Flushing {key:?}");
411 futs.push(Box::pin(self.backend.flush(key, object)));
412 }
413
414 futs.discarding_join().await
415 }
416
417 #[instrument(
425 level = "trace",
426 name = "AsyncLruCache::invalidate",
427 skip_all,
428 fields(self = &self as *const _ as usize)
429 )]
430 async unsafe fn invalidate(&self) -> io::Result<()> {
431 let mut in_use = Vec::new();
432
433 let mut map = self.map.write().await;
434 let old_map = mem::take(&mut *map);
437 for (key, mut entry) in old_map {
438 let object = entry.value.take().unwrap();
439 trace!("Evicting {key:?}");
440 match Arc::try_unwrap(object) {
441 Ok(object) => {
442 unsafe { self.backend.evict(key, object) };
444 }
445
446 Err(arc) => {
447 trace!("Entry is still in use, retaining it");
448 entry.value = Some(arc);
449 map.insert(key, entry);
450 in_use.push(key);
451 }
452 }
453 }
454
455 if in_use.is_empty() {
456 self.flush_before.lock().await.clear();
457 Ok(())
458 } else {
459 Err(io::Error::other(format!(
460 "Cannot invalidate cache, entries still in use: {}",
461 in_use
462 .iter()
463 .map(|key| format!("{key:?}"))
464 .collect::<Vec<String>>()
465 .join(", "),
466 )))
467 }
468 }
469}
470
471impl<V> AsyncLruCacheEntry<V> {
472 fn value(&self) -> &Arc<V> {
474 self.value.as_ref().unwrap()
475 }
476}
477
478#[async_trait(?Send)]
479impl<
480 K: Clone + Copy + Debug + PartialEq + Eq + Hash + Send + Sync,
481 V: Send + Sync,
482 B: AsyncLruCacheBackend<Key = K, Value = V>,
483 > FlushableCache for AsyncLruCacheInner<K, V, B>
484{
485 async fn flush(&self) -> io::Result<()> {
486 AsyncLruCacheInner::<K, V, B>::flush(self).await
487 }
488
489 async fn check_circular(&self, other: &Arc<dyn FlushableCache>) -> bool {
490 let deps = self.flush_before.lock().await;
491 for dep in deps.iter() {
492 if Arc::ptr_eq(dep, other) {
493 return true;
494 }
495 }
496 false
497 }
498}