imago/storage/drivers.rs
1//! Internal functionality for storage drivers.
2
3use crate::misc_helpers::Overlaps;
4use crate::vector_select::FutureVector;
5use std::ops::Range;
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::sync::Arc;
8use tokio::sync::oneshot;
9
10/// Helper object for the [`StorageExt`](crate::StorageExt) implementation.
11///
12/// State such as write blockers needs to be kept somewhere, and instead of introducing a wrapper
13/// (that might be bypassed), we store it directly in the [`Storage`](crate::Storage) objects so it
14/// cannot be bypassed (at least when using the [`StorageExt`](crate::StorageExt) methods).
15#[derive(Debug, Default)]
16pub struct CommonStorageHelper {
17 /// Current in-flight write that allow concurrent writes to the same region.
18 ///
19 /// Normal non-async RwLock, so do not await while locked!
20 weak_write_blockers: std::sync::RwLock<RangeBlockedList>,
21
22 /// Current in-flight write that do not allow concurrent writes to the same region.
23 strong_write_blockers: std::sync::RwLock<RangeBlockedList>,
24}
25
26/// A list of ranges blocked for some kind of concurrent access.
27///
28/// Depending on the use, some will block all concurrent access (i.e. serializing writes will block
29/// both serializing and non-serializing writes (strong blockers)), while others will only block a
30/// subset (non-serializing writes will only block serializing writes (weak blockers)).
31#[derive(Debug, Default)]
32struct RangeBlockedList {
33 /// The list of ranges.
34 ///
35 /// Serializing writes (strong write blockers) are supposed to be rare, so it is important that
36 /// entering and removing items into/from this list is cheap, not that iterating it is.
37 blocked: Vec<Arc<RangeBlocked>>,
38}
39
40/// A range blocked for some kind of concurrent access.
41#[derive(Debug)]
42struct RangeBlocked {
43 /// The range.
44 range: Range<u64>,
45
46 /// List of requests awaiting the range to become unblocked.
47 ///
48 /// When the corresponding `RangeBlockedGuard` is dropped, these will all be awoken (via
49 /// `oneshot::Sender::send(())`).
50 ///
51 /// Normal non-async mutex, so do not await while locked!
52 waitlist: std::sync::Mutex<Vec<oneshot::Sender<()>>>,
53
54 /// Index in the corresponding `RangeBlockedList.blocked` list, so it can be dropped quickly.
55 ///
56 /// (When the corresponding `RangeBlockedGuard` is dropped, this entry is swap-removed from the
57 /// `blocked` list, and the other entry taking its place has its `index` updated.)
58 ///
59 /// Only access under `blocked` lock!
60 index: AtomicUsize,
61}
62
63/// Keeps a `RangeBlocked` alive.
64///
65/// When dropped, removes the `RangeBlocked` from its list, and wakes all requests in the `waitlist`.
66#[derive(Debug)]
67pub struct RangeBlockedGuard<'a> {
68 /// List where this blocker resides.
69 list: &'a std::sync::RwLock<RangeBlockedList>,
70
71 /// `Option`, so `drop()` can `take()` it and unwrap the `Arc`.
72 ///
73 /// Consequently, do not clone: Must have refcount 1 when dropped. (The only clone must be in
74 /// `self.list.blocked`, under index `self.block.index`.)
75 block: Option<Arc<RangeBlocked>>,
76}
77
78impl CommonStorageHelper {
79 /// Await concurrent strong write blockers for the given range.
80 ///
81 /// Strong write blockers are set up for writes that must not be intersected by any other
82 /// write. Await such intersecting concurrent write requests, and return a guard that will
83 /// delay such new writes until the guard is dropped.
84 pub async fn weak_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
85 let mut intersecting = FutureVector::new();
86
87 let range_block = {
88 // Acquire write lock first
89 let mut weak = self.weak_write_blockers.write().unwrap();
90 let strong = self.strong_write_blockers.read().unwrap();
91
92 strong.collect_intersecting_await_futures(&range, &mut intersecting);
93 weak.block(range)
94 };
95
96 intersecting.discarding_join().await.unwrap();
97
98 RangeBlockedGuard {
99 list: &self.weak_write_blockers,
100 block: Some(range_block),
101 }
102 }
103
104 /// Await any concurrent write request for the given range.
105 ///
106 /// Block the given range for any concurrent write requests until the returned guard object is
107 /// dropped. Existing requests are awaited, and new ones will be delayed.
108 pub async fn strong_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
109 let mut intersecting = FutureVector::new();
110
111 let range_block = {
112 // Acquire write lock first
113 let mut strong = self.strong_write_blockers.write().unwrap();
114 let weak = self.weak_write_blockers.read().unwrap();
115
116 weak.collect_intersecting_await_futures(&range, &mut intersecting);
117 strong.collect_intersecting_await_futures(&range, &mut intersecting);
118 strong.block(range)
119 };
120
121 intersecting.discarding_join().await.unwrap();
122
123 RangeBlockedGuard {
124 list: &self.strong_write_blockers,
125 block: Some(range_block),
126 }
127 }
128}
129
130impl RangeBlockedList {
131 /// Collects futures to await intersecting request.
132 ///
133 /// Adds a future to `future_vector` for every intersecting request; awaiting that future will
134 /// await the request.
135 fn collect_intersecting_await_futures(
136 &self,
137 check_range: &Range<u64>,
138 future_vector: &mut FutureVector<(), oneshot::error::RecvError, oneshot::Receiver<()>>,
139 ) {
140 for range_block in self.blocked.iter() {
141 if range_block.range.overlaps(check_range) {
142 let (s, r) = oneshot::channel::<()>();
143 range_block.waitlist.lock().unwrap().push(s);
144 future_vector.push(r);
145 }
146 }
147 }
148
149 /// Enter a new blocked range into the list.
150 ///
151 /// This only blocks new requests, old requests must separately be awaited by awaiting all
152 /// futures returned by `collect_intersecting_await_futures()`.
153 fn block(&mut self, range: Range<u64>) -> Arc<RangeBlocked> {
154 let range_block = Arc::new(RangeBlocked {
155 range,
156 waitlist: Default::default(),
157 index: self.blocked.len().into(),
158 });
159 self.blocked.push(Arc::clone(&range_block));
160 range_block
161 }
162}
163
164impl Drop for RangeBlockedGuard<'_> {
165 fn drop(&mut self) {
166 let block = self.block.take().unwrap();
167
168 {
169 let mut list = self.list.write().unwrap();
170 let i = block.index.load(Ordering::Relaxed);
171 let removed = list.blocked.swap_remove(i);
172 debug_assert!(Arc::ptr_eq(&removed, &block));
173 if let Some(block) = list.blocked.get(i) {
174 block.index.store(i, Ordering::Relaxed);
175 }
176 }
177
178 let block = Arc::into_inner(block).unwrap();
179 let waitlist = block.waitlist.into_inner().unwrap();
180 for waiting in waitlist {
181 waiting.send(()).unwrap();
182 }
183 }
184}