imago/storage/drivers.rs

//! Internal functionality for storage drivers.

use crate::misc_helpers::Overlaps;
use crate::vector_select::FutureVector;
use std::ops::Range;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::oneshot;

/// Helper object for the [`StorageExt`](crate::StorageExt) implementation.
///
/// State such as write blockers needs to be kept somewhere, and instead of introducing a wrapper
/// (that might be bypassed), we store it directly in the [`Storage`](crate::Storage) objects so it
/// cannot be bypassed (at least when using the [`StorageExt`](crate::StorageExt) methods).
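///
/// # Example
///
/// A minimal sketch of how a write path might use this helper; the `helper()` accessor and the
/// surrounding function are hypothetical, for illustration only:
///
/// ```ignore
/// async fn write(storage: &impl Storage, offset: u64, data: &[u8]) -> io::Result<()> {
///     // Await intersecting serializing writes, and delay new ones while the guard lives:
///     let _guard = storage
///         .helper()
///         .weak_write_blocker(offset..offset + data.len() as u64)
///         .await;
///     // ... perform the actual driver write ...
///     Ok(())
/// } // `_guard` dropped here: any waiting serializing writes are woken
/// ```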
#[derive(Debug, Default)]
pub struct CommonStorageHelper {
    /// Current in-flight writes that allow concurrent writes to the same region.
    ///
    /// Normal non-async RwLock, so do not await while locked!
    weak_write_blockers: std::sync::RwLock<RangeBlockedList>,

    /// Current in-flight writes that do not allow concurrent writes to the same region.
    strong_write_blockers: std::sync::RwLock<RangeBlockedList>,
}

/// A list of ranges blocked for some kind of concurrent access.
///
/// Depending on the use, some blockers block all concurrent access (strong blockers: serializing
/// writes block both serializing and non-serializing writes), while others only block a subset
/// (weak blockers: non-serializing writes only block serializing writes).
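///
/// In summary, which new request waits for which existing blocker:
///
/// | New request \ existing blocker | Weak (non-serializing write) | Strong (serializing write) |
/// |--------------------------------|------------------------------|----------------------------|
/// | Non-serializing write          | proceeds                     | waits                      |
/// | Serializing write              | waits                        | waits                      |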
#[derive(Debug, Default)]
struct RangeBlockedList {
    /// The list of ranges.
    ///
    /// Serializing writes (strong write blockers) are supposed to be rare, so it is important
    /// that inserting items into and removing them from this list is cheap, not that iterating
    /// it is.
    blocked: Vec<Arc<RangeBlocked>>,
}

/// A range blocked for some kind of concurrent access.
#[derive(Debug)]
struct RangeBlocked {
    /// The range.
    range: Range<u64>,

    /// List of requests waiting for the range to become unblocked.
    ///
    /// When the corresponding `RangeBlockedGuard` is dropped, these will all be awoken (via
    /// `oneshot::Sender::send(())`).
    ///
    /// Normal non-async mutex, so do not await while locked!
    waitlist: std::sync::Mutex<Vec<oneshot::Sender<()>>>,

    /// Index in the corresponding `RangeBlockedList.blocked` list, so this entry can be removed
    /// quickly.
    ///
    /// (When the corresponding `RangeBlockedGuard` is dropped, this entry is swap-removed from the
    /// `blocked` list, and the other entry taking its place has its `index` updated.)
    ///
    /// Only access under the `blocked` lock!
    index: AtomicUsize,
}

/// Keeps a `RangeBlocked` alive.
///
/// When dropped, removes the `RangeBlocked` from its list, and wakes all requests in the
/// `waitlist`.
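///
/// # Example
///
/// A sketch of the scope-based unblocking, assuming `helper` is a [`CommonStorageHelper`]:
///
/// ```ignore
/// {
///     let _guard = helper.strong_write_blocker(0..512).await;
///     // 0..512 is blocked for all concurrent writes here
/// } // `_guard` dropped: entry removed from its list, all waiters woken
/// ```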
#[derive(Debug)]
pub struct RangeBlockedGuard<'a> {
    /// List where this blocker resides.
    list: &'a std::sync::RwLock<RangeBlockedList>,

    /// `Option`, so `drop()` can `take()` it and unwrap the `Arc`.
    ///
    /// Consequently, do not clone: Must have refcount 1 when dropped.  (The only clone must be in
    /// `self.list.blocked`, under index `self.block.index`.)
    block: Option<Arc<RangeBlocked>>,
}

impl CommonStorageHelper {
    /// Await concurrent strong write blockers for the given range.
    ///
    /// Strong write blockers are set up for writes that must not be intersected by any other
    /// write.  Await such intersecting concurrent write requests, and return a guard that will
    /// delay such new writes until the guard is dropped.
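    ///
    /// # Example
    ///
    /// A sketch (a hypothetical `helper: &CommonStorageHelper` is assumed); overlapping weak
    /// blockers do not wait for each other:
    ///
    /// ```ignore
    /// let g1 = helper.weak_write_blocker(0..4096).await;
    /// let g2 = helper.weak_write_blocker(2048..8192).await; // does not wait for `g1`
    /// // A strong blocker for an overlapping range would wait for both guards to be dropped.
    /// ```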
    pub async fn weak_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
        let mut intersecting = FutureVector::new();

        let range_block = {
            // Acquire write lock first
            let mut weak = self.weak_write_blockers.write().unwrap();
            let strong = self.strong_write_blockers.read().unwrap();

            strong.collect_intersecting_await_futures(&range, &mut intersecting);
            weak.block(range)
        };

        intersecting.discarding_join().await.unwrap();

        RangeBlockedGuard {
            list: &self.weak_write_blockers,
            block: Some(range_block),
        }
    }

    /// Await any concurrent write request for the given range.
    ///
    /// Block the given range for any concurrent write requests until the returned guard object is
    /// dropped.  Existing requests are awaited, and new ones will be delayed.
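    ///
    /// # Example
    ///
    /// A sketch with a hypothetical `helper`; strong blockers await all intersecting writes,
    /// weak and strong alike:
    ///
    /// ```ignore
    /// // Suppose another task holds a weak blocker for 0..4096.  This call resolves only
    /// // once that guard is dropped; until `_strong` is dropped itself, any new write
    /// // blocker for an intersecting range will wait:
    /// let _strong = helper.strong_write_blocker(2048..3072).await;
    /// ```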
    pub async fn strong_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
        let mut intersecting = FutureVector::new();

        let range_block = {
            // Acquire write lock first
            let mut strong = self.strong_write_blockers.write().unwrap();
            let weak = self.weak_write_blockers.read().unwrap();

            weak.collect_intersecting_await_futures(&range, &mut intersecting);
            strong.collect_intersecting_await_futures(&range, &mut intersecting);
            strong.block(range)
        };

        intersecting.discarding_join().await.unwrap();

        RangeBlockedGuard {
            list: &self.strong_write_blockers,
            block: Some(range_block),
        }
    }
}

impl RangeBlockedList {
    /// Collects futures to await intersecting requests.
    ///
    /// Adds a future to `future_vector` for every intersecting request; awaiting that future will
    /// await the request.
    fn collect_intersecting_await_futures(
        &self,
        check_range: &Range<u64>,
        future_vector: &mut FutureVector<(), oneshot::error::RecvError, oneshot::Receiver<()>>,
    ) {
        for range_block in self.blocked.iter() {
            if range_block.range.overlaps(check_range) {
                let (s, r) = oneshot::channel::<()>();
                range_block.waitlist.lock().unwrap().push(s);
                future_vector.push(r);
            }
        }
    }

    /// Enter a new blocked range into the list.
    ///
    /// This only blocks new requests; existing requests must separately be awaited by awaiting
    /// all futures returned by `collect_intersecting_await_futures()`.
    fn block(&mut self, range: Range<u64>) -> Arc<RangeBlocked> {
        let range_block = Arc::new(RangeBlocked {
            range,
            waitlist: Default::default(),
            index: self.blocked.len().into(),
        });
        self.blocked.push(Arc::clone(&range_block));
        range_block
    }
}

impl Drop for RangeBlockedGuard<'_> {
    fn drop(&mut self) {
        let block = self.block.take().unwrap();

        {
            let mut list = self.list.write().unwrap();
            let i = block.index.load(Ordering::Relaxed);
            // Swap-remove this entry, then fix up the index of the entry that took its place.
            let removed = list.blocked.swap_remove(i);
            debug_assert!(Arc::ptr_eq(&removed, &block));
            if let Some(block) = list.blocked.get(i) {
                block.index.store(i, Ordering::Relaxed);
            }
        }

        // The only clone was the one just removed from the list, so the refcount must be 1 now
        // and the `Arc` can be unwrapped.
        let block = Arc::into_inner(block).unwrap();
        let waitlist = block.waitlist.into_inner().unwrap();
        for waiting in waitlist {
            // Ignore send errors: the receiver may already have been dropped (e.g. because the
            // waiting request was cancelled).
            let _ = waiting.send(());
        }
    }
}