imago/storage/ext.rs

//! Provides the `StorageExt` trait for more convenient access.
//!
//! `Storage` is provided by the driver, so it is supposed to be simple and only contain what’s
//! necessary.  `StorageExt` builds on that to provide more convenient access, e.g. it allows
//! unaligned requests and provides write serialization.
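//!
//! A minimal usage sketch (assumptions for illustration: the re-export paths, and that plain byte
//! slices convert into the crate’s I/O vector types):
//!
//! ```ignore
//! use imago::{Storage, StorageExt}; // `StorageExt` path assumed; adjust to the actual re-export
//!
//! async fn example<S: Storage>(storage: &S) -> std::io::Result<()> {
//!     let mut buf = vec![0u8; 200];
//!     // Offsets and lengths need not be aligned; `StorageExt` falls back to bounce buffers
//!     // (and, for writes, a read-modify-write cycle) as needed.
//!     storage.read(buf.as_mut_slice(), 100).await?;
//!     storage.write(buf.as_slice(), 4096 + 100).await?;
//!     storage.write_zeroes(8192, 4096).await?;
//!     Ok(())
//! }
//! ```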

use super::drivers::RangeBlockedGuard;
use crate::io_buffers::{IoBuffer, IoVector, IoVectorMut, IoVectorTrait};
use crate::Storage;
use std::ops::Range;
use std::{cmp, io};
use tracing::trace;

/// Helper methods for storage objects.
///
/// Provides some more convenient methods for accessing storage objects.
pub trait StorageExt: Storage {
    /// Read data at `offset` into `bufv`.
    ///
    /// Reads until `bufv` is filled completely, i.e. will not do short reads.  When reaching the
    /// end of file, the rest of `bufv` is filled with 0.
    ///
    /// Checks alignment.  If anything does not meet the requirements, enforces it (using ephemeral
    /// bounce buffers).
    #[allow(async_fn_in_trait)] // No need for Send
    async fn readv(&self, bufv: IoVectorMut<'_>, offset: u64) -> io::Result<()>;

    /// Write data from `bufv` to `offset`.
    ///
    /// Writes all data from `bufv`, i.e. will not do short writes.  When reaching the end of file,
    /// it is grown as necessary so that the new end of file will be at `offset + bufv.len()`.
    ///
    /// If growing is not possible, expect writes beyond the end of file (even if only partially)
    /// to fail.
    ///
    /// Checks alignment.  If anything does not meet the requirements, enforces it using bounce
    /// buffers and a read-modify-write cycle that blocks concurrent writes to the affected area.
    #[allow(async_fn_in_trait)] // No need for Send
    async fn writev(&self, bufv: IoVector<'_>, offset: u64) -> io::Result<()>;

    /// Read data at `offset` into `buf`.
    ///
    /// Reads until `buf` is filled completely, i.e. will not do short reads.  When reaching the
    /// end of file, the rest of `buf` is filled with 0.
    ///
    /// Checks alignment.  If anything does not meet the requirements, enforces it (using ephemeral
    /// bounce buffers).
    #[allow(async_fn_in_trait)] // No need for Send
    async fn read(&self, buf: impl Into<IoVectorMut<'_>>, offset: u64) -> io::Result<()>;

    /// Write data from `buf` to `offset`.
    ///
    /// Writes all data from `buf`, i.e. will not do short writes.  When reaching the end of file,
    /// it is grown as necessary so that the new end of file will be at `offset + buf.len()`.
    ///
    /// If growing is not possible, expect writes beyond the end of file (even if only partially)
    /// to fail.
    ///
    /// Checks alignment.  If anything does not meet the requirements, enforces it using bounce
    /// buffers and a read-modify-write cycle that blocks concurrent writes to the affected area.
    #[allow(async_fn_in_trait)] // No need for Send
    async fn write(&self, buf: impl Into<IoVector<'_>>, offset: u64) -> io::Result<()>;

    /// Ensure the given range reads back as zeroes.
    #[allow(async_fn_in_trait)] // No need for Send
    async fn write_zeroes(&self, offset: u64, length: u64) -> io::Result<()>;

    /// Discard the given range, with undefined contents when read back.
    ///
    /// Tell the storage layer this range is no longer needed and need not be backed by actual
    /// storage.  When read back, the data read will be undefined, i.e. not necessarily zeroes.
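    ///
    /// Minimal sketch (assuming some storage object `storage`; offsets chosen arbitrarily):
    ///
    /// ```ignore
    /// // Tell the storage that this 1 MiB region no longer needs to be backed.
    /// storage.discard(0x10_0000, 0x10_0000).await?;
    /// ```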
    #[allow(async_fn_in_trait)] // No need for Send
    async fn discard(&self, offset: u64, length: u64) -> io::Result<()>;

    /// Await concurrent strong write blockers for the given range.
    ///
    /// Strong write blockers are set up for writes that must not be intersected by any other
    /// write.  Await such intersecting concurrent write requests, and return a guard that will
    /// delay such new writes until the guard is dropped.
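    ///
    /// Minimal sketch of the pattern used by `writev()` in this module for aligned writes (range
    /// chosen arbitrarily):
    ///
    /// ```ignore
    /// let _guard = storage.weak_write_blocker(offset..offset + len).await;
    /// // Perform the aligned write; the block is lifted when `_guard` is dropped.
    /// ```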
    #[allow(async_fn_in_trait)] // No need for Send
    async fn weak_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_>;

    /// Await any concurrent write request for the given range.
    ///
    /// Block the given range for any concurrent write requests until the returned guard object is
    /// dropped.  Existing requests are awaited, and new ones will be delayed.
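    ///
    /// Minimal sketch of the pattern used by `writev()` in this module for its read-modify-write
    /// path:
    ///
    /// ```ignore
    /// let _guard = storage.strong_write_blocker(padded_offset..padded_end).await;
    /// // Read the padding, modify the bounce buffer, write it back; then drop `_guard`.
    /// ```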
    #[allow(async_fn_in_trait)] // No need for Send
    async fn strong_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_>;
}

impl<S: Storage> StorageExt for S {
    async fn readv(&self, mut bufv: IoVectorMut<'_>, offset: u64) -> io::Result<()> {
        if bufv.is_empty() {
            return Ok(());
        }

        let mem_align = self.mem_align();
        let req_align = self.req_align();

        if is_aligned(&bufv, offset, mem_align, req_align) {
            // Safe: Alignment checked
            return unsafe { self.pure_readv(bufv, offset) }.await;
        }

        trace!(
            "Unaligned read: {offset:#x} + {} (size: {:#x})",
            bufv.len(),
            self.size().unwrap()
        );

        let req_align_mask = req_align as u64 - 1;
        // Length must be aligned to both memory and request alignments
        let len_align_mask = req_align_mask | (mem_align as u64 - 1);
        debug_assert!((len_align_mask + 1).is_multiple_of(req_align as u64));

        let unpadded_end = offset + bufv.len();
        let padded_offset = offset & !req_align_mask;
        // This will over-align at the end of file (aligning to exactly the end of file would be
        // sufficient), but it is easier this way.
        let padded_end = (unpadded_end + req_align_mask) & !req_align_mask;
        // Now also align to memory alignment
        let padded_len = (padded_end - padded_offset + len_align_mask) & !len_align_mask;
        let padded_end = padded_offset + padded_len;
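        // Worked example (numbers assumed for illustration): with `req_align == 512` and
        // `mem_align == 4096`, reading 200 bytes at offset 100 gives `req_align_mask == 511`,
        // `len_align_mask == 4095`, `padded_offset == 0`, a request-aligned end of 512, and thus
        // `padded_len == 4096` after rounding up to the combined alignment.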

        let padded_len: usize = (padded_end - padded_offset)
            .try_into()
            .map_err(|e| io::Error::other(format!("Cannot realign read: {e}")))?;

        trace!("Padded read: {padded_offset:#x} + {padded_len}");

        let mut bounce_buf = IoBuffer::new(padded_len, mem_align)?;

        // Safe: Alignment enforced
        unsafe { self.pure_readv(bounce_buf.as_mut().into(), padded_offset) }.await?;

        let in_buf_ofs = (offset - padded_offset) as usize;
        // Must fit in `usize` because `padded_len: usize`
        let in_buf_end = (unpadded_end - padded_offset) as usize;

        bufv.copy_from_slice(bounce_buf.as_ref_range(in_buf_ofs..in_buf_end).into_slice());

        Ok(())
    }

    async fn writev(&self, bufv: IoVector<'_>, offset: u64) -> io::Result<()> {
        if bufv.is_empty() {
            return Ok(());
        }

        let mem_align = self.mem_align();
        let req_align = self.req_align();

        if is_aligned(&bufv, offset, mem_align, req_align) {
            let _sw_guard = self.weak_write_blocker(offset..(offset + bufv.len())).await;

            // Safe: Alignment checked, and weak write blocker set up
            return unsafe { self.pure_writev(bufv, offset) }.await;
        }

        trace!(
            "Unaligned write: {offset:#x} + {} (size: {:#x})",
            bufv.len(),
            self.size().unwrap()
        );

        let req_align_mask = req_align - 1;
        // Length must be aligned to both memory and request alignments
        let len_align_mask = req_align_mask | (mem_align - 1);
        let len_align = len_align_mask + 1;
        debug_assert!(len_align.is_multiple_of(req_align));

        let unpadded_end = offset + bufv.len();
        let padded_offset = offset & !(req_align_mask as u64);
        // This will over-align at the end of file (aligning to exactly the end of file would be
        // sufficient), but it is easier this way.  Small TODO, as this will indeed increase the
        // file length (which the over-alignment in `readv()` does not).
        let padded_end = (unpadded_end + req_align_mask as u64) & !(req_align_mask as u64);
        // Now also align to memory alignment
        let padded_len =
            (padded_end - padded_offset + len_align_mask as u64) & !(len_align_mask as u64);
        let padded_end = padded_offset + padded_len;

        let padded_len: usize = (padded_end - padded_offset)
            .try_into()
            .map_err(|e| io::Error::other(format!("Cannot realign write: {e}")))?;

        trace!("Padded write: {padded_offset:#x} + {padded_len}");

        let mut bounce_buf = IoBuffer::new(padded_len, mem_align)?;
        assert!(padded_len >= len_align && padded_len & len_align_mask == 0);

        // For the strong blocker, just the RMW regions (head and tail) would be enough.  However,
        // we don’t expect any concurrent writes to the non-RMW (pure write) regions (it is
        // unlikely that the guest would write to the same area twice concurrently), so we don’t
        // need to optimize for it.  On the other hand, writes to the RMW regions are likely
        // (adjacent writes), so those will be blocked either way.
        // Instating fewer blockers makes them less expensive to check, though.
        let _sw_guard = self.strong_write_blocker(padded_offset..padded_end).await;

        let in_buf_ofs = (offset - padded_offset) as usize;
        // Must fit in `usize` because `padded_len: usize`
        let in_buf_end = (unpadded_end - padded_offset) as usize;

        // RMW part 1: Read

        let head_len = in_buf_ofs;
        let aligned_head_len = (head_len + len_align_mask) & !len_align_mask;

        let tail_len = padded_len - in_buf_end;
        let aligned_tail_len = (tail_len + len_align_mask) & !len_align_mask;

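        // Worked example (numbers assumed for illustration): with `req_align == 512` and
        // `mem_align == 4096`, a 20000-byte write at offset 100 yields `padded_offset == 0`,
        // `padded_len == 20480`, `in_buf_ofs == 100`, and `in_buf_end == 20100`; the 100-byte
        // head rounds up to 4096 bytes, the 380-byte tail also rounds up to 4096 bytes, and
        // since that does not cover the whole buffer, head and tail are read separately below.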
        if aligned_head_len + aligned_tail_len == padded_len {
            // Must read the whole bounce buffer
            // Safe: Alignment enforced
            unsafe { self.pure_readv(bounce_buf.as_mut().into(), padded_offset) }.await?;
        } else {
            if aligned_head_len > 0 {
                let head_bufv = bounce_buf.as_mut_range(0..aligned_head_len).into();
                // Safe: Alignment enforced
                unsafe { self.pure_readv(head_bufv, padded_offset) }.await?;
            }
            if aligned_tail_len > 0 {
                let tail_start = padded_len - aligned_tail_len;
                let tail_bufv = bounce_buf.as_mut_range(tail_start..padded_len).into();
                // Safe: Alignment enforced
                unsafe { self.pure_readv(tail_bufv, padded_offset + tail_start as u64) }.await?;
            }
        }

        // RMW part 2: Modify
        bufv.copy_into_slice(bounce_buf.as_mut_range(in_buf_ofs..in_buf_end).into_slice());

        // RMW part 3: Write
        // Safe: Alignment enforced, and strong write blocker set up
        unsafe { self.pure_writev(bounce_buf.as_ref().into(), padded_offset) }.await
    }

    async fn read(&self, buf: impl Into<IoVectorMut<'_>>, offset: u64) -> io::Result<()> {
        self.readv(buf.into(), offset).await
    }

    async fn write(&self, buf: impl Into<IoVector<'_>>, offset: u64) -> io::Result<()> {
        self.writev(buf.into(), offset).await
    }

    async fn write_zeroes(&self, offset: u64, length: u64) -> io::Result<()> {
        let zero_align = self.zero_align();
        debug_assert!(zero_align.is_power_of_two());
        let align_mask = zero_align as u64 - 1;

        let unaligned_end = offset
            .checked_add(length)
            .ok_or_else(|| io::Error::other("Zero-write wrap-around"))?;
        let aligned_offset = (offset + align_mask) & !align_mask;
        let aligned_end = unaligned_end & !align_mask;
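        // Worked example (numbers assumed for illustration): with `zero_align == 4096`, zeroing
        // 10000 bytes at offset 100 gives `aligned_offset == 4096` and `aligned_end == 8192`; the
        // aligned middle is zeroed below, while the 3996-byte head and the 1908-byte tail are
        // written out from an explicit zero buffer afterwards.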

        if aligned_end > aligned_offset {
            let _sw_guard = self.weak_write_blocker(aligned_offset..aligned_end).await;
            // Safe: Alignment checked, and weak write blocker set up
            unsafe { self.pure_write_zeroes(aligned_offset, aligned_end - aligned_offset) }.await?;
        }

        let zero_buf = if aligned_offset > offset || aligned_end < unaligned_end {
            let mut buf = IoBuffer::new(
                cmp::max(aligned_offset - offset, unaligned_end - aligned_end) as usize,
                self.mem_align(),
            )?;
            buf.as_mut().into_slice().fill(0);
            Some(buf)
        } else {
            None
        };

        if aligned_offset > offset {
            // Clamp the head to the end of the requested range: if the whole range lies within a
            // single `zero_align` block, there is no aligned middle part, and the head must not
            // zero anything past `unaligned_end`.
            let head_end = cmp::min(aligned_offset, unaligned_end);
            let buf = zero_buf
                .as_ref()
                .unwrap()
                .as_ref_range(0..((head_end - offset) as usize));
            self.write(buf, offset).await?;
        }
        if aligned_end < unaligned_end && aligned_end >= aligned_offset {
            // The second condition excludes the single-block case (`aligned_end < aligned_offset`),
            // which the clamped head write above already covers completely; without it, this tail
            // write would zero data before `offset`.
            let buf = zero_buf
                .as_ref()
                .unwrap()
                .as_ref_range(0..((unaligned_end - aligned_end) as usize));
            self.write(buf, aligned_end).await?;
        }

        Ok(())
    }

    async fn discard(&self, offset: u64, length: u64) -> io::Result<()> {
        let discard_align = self.discard_align();
        debug_assert!(discard_align.is_power_of_two());
        let align_mask = discard_align as u64 - 1;

        let unaligned_end = offset
            .checked_add(length)
            .ok_or_else(|| io::Error::other("Discard wrap-around"))?;
        let aligned_offset = (offset + align_mask) & !align_mask;
        let aligned_end = unaligned_end & !align_mask;

        if aligned_end > aligned_offset {
            let _sw_guard = self.weak_write_blocker(aligned_offset..aligned_end).await;
            // Safe: Alignment checked, and weak write blocker set up
            unsafe { self.pure_discard(aligned_offset, aligned_end - aligned_offset) }.await?;
        }

        // Nothing to do for the unaligned part; discarding is always just advisory.

        Ok(())
    }

    async fn weak_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
        self.get_storage_helper().weak_write_blocker(range).await
    }

    async fn strong_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
        self.get_storage_helper().strong_write_blocker(range).await
    }
}

/// Check whether the given request is aligned.
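///
/// For example (numbers assumed for illustration), with `req_align == 512`, an `offset` of 100 or
/// a total length of 700 bytes is rejected, while a request-aligned offset and length pass only
/// if the buffers themselves also satisfy `is_aligned(mem_align, req_align)`.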
fn is_aligned<V: IoVectorTrait>(bufv: &V, offset: u64, mem_align: usize, req_align: usize) -> bool {
    debug_assert!(mem_align.is_power_of_two() && req_align.is_power_of_two());

    let req_align_mask = req_align as u64 - 1;

    offset & req_align_mask == 0
        && bufv.len() & req_align_mask == 0
        && bufv.is_aligned(mem_align, req_align)
}

/// Write zero data to the given area.
///
/// In contrast to `write_zeroes()` functions, this one will actually write zero data, fully
/// allocated.
pub(crate) async fn write_full_zeroes<S: StorageExt>(
    storage: S,
    mut offset: u64,
    mut length: u64,
) -> io::Result<()> {
    let buflen = cmp::min(length, 1048576) as usize;
    let mut buf = IoBuffer::new(buflen, storage.mem_align())?;
    buf.as_mut().into_slice().fill(0);

    let req_align = storage.req_align();
    let req_align_mask = (req_align - 1) as u64;

    while length > 0 {
        let mut chunk_length = cmp::min(length, 1048576) as usize;
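        // If `offset` is not request-aligned, trim the first chunk so that it ends at the next
        // request-alignment boundary; e.g. (numbers assumed for illustration) with
        // `req_align == 512` and `offset == 100`, the first chunk is capped at 412 bytes, so
        // every later chunk starts aligned.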
        if offset & req_align_mask != 0 {
            chunk_length = cmp::min(chunk_length, req_align - (offset & req_align_mask) as usize);
        }
        storage
            .write(buf.as_ref_range(0..chunk_length), offset)
            .await?;
        offset += chunk_length as u64;
        length -= chunk_length as u64;
    }

    Ok(())
}