imago/storage/
drivers.rs

1//! Internal functionality for storage drivers.
2
3use crate::misc_helpers::Overlaps;
4use crate::vector_select::FutureVector;
5use std::ops::Range;
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::sync::Arc;
8use tokio::sync::oneshot;
9
10/// Helper object for the [`StorageExt`](crate::StorageExt) implementation.
11///
12/// State such as write blockers needs to be kept somewhere, and instead of introducing a wrapper
13/// (that might be bypassed), we store it directly in the [`Storage`](crate::Storage) objects so it
14/// cannot be bypassed (at least when using the [`StorageExt`](crate::StorageExt) methods).
15#[derive(Debug, Default)]
16pub struct CommonStorageHelper {
17    /// Current in-flight write that allow concurrent writes to the same region.
18    ///
19    /// Normal non-async RwLock, so do not await while locked!
20    weak_write_blockers: std::sync::RwLock<RangeBlockedList>,
21
22    /// Current in-flight write that do not allow concurrent writes to the same region.
23    strong_write_blockers: std::sync::RwLock<RangeBlockedList>,
24}
25
26/// A list of ranges blocked for some kind of concurrent access.
27///
28/// Depending on the use, some will block all concurrent access (i.e. serializing writes will block
29/// both serializing and non-serializing writes (strong blockers)), while others will only block a
30/// subset (non-serializing writes will only block serializing writes (weak blockers)).
31#[derive(Debug, Default)]
32struct RangeBlockedList {
33    /// The list of ranges.
34    ///
35    /// Serializing writes (strong write blockers) are supposed to be rare, so it is important that
36    /// entering and removing items into/from this list is cheap, not that iterating it is.
37    blocked: Vec<Arc<RangeBlocked>>,
38}
39
40/// A range blocked for some kind of concurrent access.
41#[derive(Debug)]
42struct RangeBlocked {
43    /// The range.
44    range: Range<u64>,
45
46    /// List of requests awaiting the range to become unblocked.
47    ///
48    /// When the corresponding `RangeBlockedGuard` is dropped, these will all be awoken (via
49    /// `oneshot::Sender::send(())`).
50    ///
51    /// Normal non-async mutex, so do not await while locked!
52    waitlist: std::sync::Mutex<Vec<oneshot::Sender<()>>>,
53
54    /// Index in the corresponding `RangeBlockedList.blocked` list, so it can be dropped quickly.
55    ///
56    /// (When the corresponding `RangeBlockedGuard` is dropped, this entry is swap-removed from the
57    /// `blocked` list, and the other entry taking its place has its `index` updated.)
58    ///
59    /// Only access under `blocked` lock!
60    index: AtomicUsize,
61}
62
63/// Keeps a `RangeBlocked` alive.
64///
65/// When dropped, removes the `RangeBlocked` from its list, and wakes all requests in the `waitlist`.
66#[derive(Debug)]
67pub struct RangeBlockedGuard<'a> {
68    /// List where this blocker resides.
69    list: &'a std::sync::RwLock<RangeBlockedList>,
70
71    /// `Option`, so `drop()` can `take()` it and unwrap the `Arc`.
72    ///
73    /// Consequently, do not clone: Must have refcount 1 when dropped.  (The only clone must be in
74    /// `self.list.blocked`, under index `self.block.index`.)
75    block: Option<Arc<RangeBlocked>>,
76}
77
78impl CommonStorageHelper {
79    /// Await concurrent strong write blockers for the given range.
80    ///
81    /// Strong write blockers are set up for writes that must not be intersected by any other
82    /// write.  Await such intersecting concurrent write requests, and return a guard that will
83    /// delay such new writes until the guard is dropped.
84    pub async fn weak_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
85        let mut intersecting = FutureVector::new();
86
87        // Create `RangeBlockedGuard` before the `await` below, so if the future is dropped,
88        // `RangeBlockedGuard::drop()` will run, removing the blocker from the list
89        let guard = {
90            // Consistent ordering to avoid deadlock: Always acquire weak before strong
91            let mut weak = self.weak_write_blockers.write().unwrap();
92            let strong = self.strong_write_blockers.read().unwrap();
93
94            strong.collect_intersecting_await_futures(&range, &mut intersecting);
95
96            RangeBlockedGuard {
97                list: &self.weak_write_blockers,
98                block: Some(weak.block(range)),
99            }
100        };
101
102        // `RecvError` means the blocker's guard was dropped without signaling, so the blocking
103        // operation is gone, and thus waiting for it is pointless.  We must still wait for all
104        // other overlapping blockers, so drain until all are actually done, ignoring errors.
105        while intersecting.discarding_join().await.is_err() {}
106
107        guard
108    }
109
110    /// Await any concurrent write request for the given range.
111    ///
112    /// Block the given range for any concurrent write requests until the returned guard object is
113    /// dropped.  Existing requests are awaited, and new ones will be delayed.
114    pub async fn strong_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
115        let mut intersecting = FutureVector::new();
116
117        // Create `RangeBlockedGuard` before the `await` below, so if the future is dropped,
118        // `RangeBlockedGuard::drop()` will run, removing the blocker from the list
119        let guard = {
120            // Consistent ordering to avoid deadlock: Always acquire weak before strong
121            let weak = self.weak_write_blockers.read().unwrap();
122            let mut strong = self.strong_write_blockers.write().unwrap();
123
124            weak.collect_intersecting_await_futures(&range, &mut intersecting);
125            strong.collect_intersecting_await_futures(&range, &mut intersecting);
126
127            RangeBlockedGuard {
128                list: &self.strong_write_blockers,
129                block: Some(strong.block(range)),
130            }
131        };
132
133        // `RecvError` means the blocker's guard was dropped without signaling, so the blocking
134        // operation is gone, and thus waiting for it is pointless.  We must still wait for all
135        // other overlapping blockers, so drain until all are actually done, ignoring errors.
136        while intersecting.discarding_join().await.is_err() {}
137
138        guard
139    }
140}
141
142impl RangeBlockedList {
143    /// Collects futures to await intersecting request.
144    ///
145    /// Adds a future to `future_vector` for every intersecting request; awaiting that future will
146    /// await the request.
147    fn collect_intersecting_await_futures(
148        &self,
149        check_range: &Range<u64>,
150        future_vector: &mut FutureVector<(), oneshot::error::RecvError, oneshot::Receiver<()>>,
151    ) {
152        for range_block in self.blocked.iter() {
153            if range_block.range.overlaps(check_range) {
154                let (s, r) = oneshot::channel::<()>();
155                range_block.waitlist.lock().unwrap().push(s);
156                future_vector.push(r);
157            }
158        }
159    }
160
161    /// Enter a new blocked range into the list.
162    ///
163    /// This only blocks new requests, old requests must separately be awaited by awaiting all
164    /// futures returned by `collect_intersecting_await_futures()`.
165    fn block(&mut self, range: Range<u64>) -> Arc<RangeBlocked> {
166        let range_block = Arc::new(RangeBlocked {
167            range,
168            waitlist: Default::default(),
169            index: self.blocked.len().into(),
170        });
171        self.blocked.push(Arc::clone(&range_block));
172        range_block
173    }
174}
175
176impl Drop for RangeBlockedGuard<'_> {
177    fn drop(&mut self) {
178        let block = self.block.take().unwrap();
179
180        {
181            let mut list = self.list.write().unwrap();
182            let i = block.index.load(Ordering::Relaxed);
183            let removed = list.blocked.swap_remove(i);
184            debug_assert!(Arc::ptr_eq(&removed, &block));
185            if let Some(block) = list.blocked.get(i) {
186                block.index.store(i, Ordering::Relaxed);
187            }
188        }
189
190        let block = Arc::into_inner(block).unwrap();
191        let waitlist = block.waitlist.into_inner().unwrap();
192        for waiting in waitlist {
193            // If the receiving end was dropped (e.g. because the request was dropped), then just
194            // ignore that
195            let _ = waiting.send(());
196        }
197    }
198}