imago/storage/ext.rs

//! Provides the `StorageExt` trait for more convenient access.
//!
//! `Storage` is provided by the driver, so it is supposed to be simple and only contain what’s
//! necessary.  `StorageExt` builds on that to provide more convenient access, e.g. it allows
//! unaligned requests and provides write serialization.
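//!
//! # Example
//!
//! A minimal usage sketch (illustrative only, not taken from this crate’s tests; it assumes
//! `Storage`, `StorageExt`, and `IoBuffer` are reachable under the paths used below, and that
//! `storage` is some established storage object):
//!
//! ```ignore
//! use imago::io_buffers::IoBuffer;
//! use imago::{Storage, StorageExt};
//!
//! async fn example<S: Storage>(storage: &S) -> std::io::Result<()> {
//!     // 512 bytes, aligned to the storage object’s memory alignment requirement
//!     let mut buf = IoBuffer::new(512, storage.mem_align())?;
//!
//!     // Offset 3 is almost certainly not request-aligned; `StorageExt` transparently uses a
//!     // bounce buffer (and, for writes, an RMW cycle) to make this work anyway.
//!     storage.read(buf.as_mut(), 3).await?;
//!     storage.write(buf.as_ref(), 3).await?;
//!     Ok(())
//! }
//! ```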

use super::drivers::RangeBlockedGuard;
use crate::io_buffers::{IoBuffer, IoVector, IoVectorMut, IoVectorTrait};
use crate::Storage;
use std::ops::Range;
use std::{cmp, io};
use tracing::trace;

/// Helper methods for storage objects.
///
/// Provides some more convenient methods for accessing storage objects.
pub trait StorageExt: Storage {
    /// Read data at `offset` into `bufv`.
    ///
    /// Reads until `bufv` is filled completely, i.e. will not do short reads.  When reaching the
    /// end of file, the rest of `bufv` is filled with 0.
    ///
    /// Checks alignment.  If anything does not meet the requirements, enforces it (using ephemeral
    /// bounce buffers).
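    ///
    /// A sketch of the end-of-file behavior (illustrative only; assumes `storage` implements
    /// [`Storage`] and is exactly 1000 bytes long):
    ///
    /// ```ignore
    /// let mut buf = IoBuffer::new(512, storage.mem_align())?;
    /// storage.readv(buf.as_mut().into(), 768).await?;
    /// // Bytes 768..1000 of the file fill the first 232 bytes of `buf`;
    /// // the remaining 280 bytes read as zero.
    /// ```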
    #[allow(async_fn_in_trait)] // No need for Send
    async fn readv(&self, bufv: IoVectorMut<'_>, offset: u64) -> io::Result<()>;

    /// Write data from `bufv` to `offset`.
    ///
    /// Writes all data from `bufv`, i.e. will not do short writes.  When reaching the end of file,
    /// it is grown as necessary so that the new end of file will be at `offset + bufv.len()`.
    ///
    /// If growing is not possible, expect writes beyond the end of file (even if only partially)
    /// to fail.
    ///
    /// Checks alignment.  If anything does not meet the requirements, enforces it using bounce
    /// buffers and a read-modify-write cycle that blocks concurrent writes to the affected area.
    #[allow(async_fn_in_trait)] // No need for Send
    async fn writev(&self, bufv: IoVector<'_>, offset: u64) -> io::Result<()>;

    /// Read data at `offset` into `buf`.
    ///
    /// Reads until `buf` is filled completely, i.e. will not do short reads.  When reaching the
    /// end of file, the rest of `buf` is filled with 0.
    ///
    /// Checks alignment.  If anything does not meet the requirements, enforces it (using ephemeral
    /// bounce buffers).
    #[allow(async_fn_in_trait)] // No need for Send
    async fn read(&self, buf: impl Into<IoVectorMut<'_>>, offset: u64) -> io::Result<()>;

    /// Write data from `buf` to `offset`.
    ///
    /// Writes all data from `buf`, i.e. will not do short writes.  When reaching the end of file,
    /// it is grown as necessary so that the new end of file will be at `offset + buf.len()`.
    ///
    /// If growing is not possible, expect writes beyond the end of file (even if only partially)
    /// to fail.
    ///
    /// Checks alignment.  If anything does not meet the requirements, enforces it using bounce
    /// buffers and a read-modify-write cycle that blocks concurrent writes to the affected area.
    #[allow(async_fn_in_trait)] // No need for Send
    async fn write(&self, buf: impl Into<IoVector<'_>>, offset: u64) -> io::Result<()>;

    /// Ensure the given range reads back as zeroes.
    #[allow(async_fn_in_trait)] // No need for Send
    async fn write_zeroes(&self, offset: u64, length: u64) -> io::Result<()>;

    /// Ensure the given range is allocated and reads back as zeroes.
    #[allow(async_fn_in_trait)] // No need for Send
    async fn write_allocated_zeroes(&self, offset: u64, length: u64) -> io::Result<()>;

    /// Discard the given range, with undefined contents when read back.
    ///
    /// Tell the storage layer this range is no longer needed and need not be backed by actual
    /// storage.  When read back, the data read will be undefined, i.e. not necessarily zeroes.
    #[allow(async_fn_in_trait)] // No need for Send
    async fn discard(&self, offset: u64, length: u64) -> io::Result<()>;

    /// Await concurrent strong write blockers for the given range.
    ///
    /// Strong write blockers are set up for writes that must not be intersected by any other
    /// write.  Await such intersecting concurrent write requests, and return a guard that will
    /// delay such new writes until the guard is dropped.
    #[allow(async_fn_in_trait)] // No need for Send
    async fn weak_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_>;

    /// Await any concurrent write request for the given range.
    ///
    /// Block the given range for any concurrent write requests until the returned guard object is
    /// dropped.  Existing requests are awaited, and new ones will be delayed.
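    ///
    /// A sketch of how a guard might be used (illustrative only; assumes `storage` implements
    /// [`Storage`]):
    ///
    /// ```ignore
    /// let _guard = storage.strong_write_blocker(0..4096).await;
    /// // All writes to bytes 0..4096 are now blocked: existing ones have completed, new ones
    /// // are delayed.  Perform e.g. a read-modify-write cycle on that range here.
    /// // Dropping `_guard` releases the range again.
    /// ```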
    #[allow(async_fn_in_trait)] // No need for Send
    async fn strong_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_>;
}

impl<S: Storage> StorageExt for S {
    async fn readv(&self, mut bufv: IoVectorMut<'_>, offset: u64) -> io::Result<()> {
        if bufv.is_empty() {
            return Ok(());
        }

        let mem_align = self.mem_align();
        let req_align = self.req_align();

        if is_aligned(&bufv, offset, mem_align, req_align) {
            // Safe: Alignment checked
            return unsafe { self.pure_readv(bufv, offset) }.await;
        }

        trace!(
            "Unaligned read: {offset:#x} + {} (size: {:#x})",
            bufv.len(),
            self.size().unwrap()
        );

        let req_align_mask = req_align as u64 - 1;
        // Length must be aligned to both memory and request alignments
        let len_align_mask = req_align_mask | (mem_align as u64 - 1);
        debug_assert!((len_align_mask + 1).is_multiple_of(req_align as u64));

        let unpadded_end = offset + bufv.len();
        let padded_offset = offset & !req_align_mask;
        // This will over-align at the end of file (aligning to exactly the end of file would be
        // sufficient), but it is easier this way.
        let padded_end = (unpadded_end + req_align_mask) & !req_align_mask;
        // Now also align to memory alignment
        let padded_len = (padded_end - padded_offset + len_align_mask) & !(len_align_mask);
        let padded_end = padded_offset + padded_len;
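        // For example (illustrative values): with `req_align == 512`, `mem_align == 4096`,
        // `offset == 0x1234`, and `bufv.len() == 0x100`, this yields `padded_offset == 0x1200`,
        // a request-aligned end of 0x1400, and, after rounding the length up to
        // `len_align_mask + 1` (4096), `padded_len == 0x1000` and thus `padded_end == 0x2200`.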

        let padded_len: usize = (padded_end - padded_offset)
            .try_into()
            .map_err(|e| io::Error::other(format!("Cannot realign read: {e}")))?;

        trace!("Padded read: {padded_offset:#x} + {padded_len}");

        let mut bounce_buf = IoBuffer::new(padded_len, mem_align)?;

        // Safe: Alignment enforced
        unsafe { self.pure_readv(bounce_buf.as_mut().into(), padded_offset) }.await?;

        let in_buf_ofs = (offset - padded_offset) as usize;
        // Must fit in `usize` because `padded_len: usize`
        let in_buf_end = (unpadded_end - padded_offset) as usize;

        bufv.copy_from_slice(bounce_buf.as_ref_range(in_buf_ofs..in_buf_end).into_slice());

        Ok(())
    }

    async fn writev(&self, bufv: IoVector<'_>, offset: u64) -> io::Result<()> {
        if bufv.is_empty() {
            return Ok(());
        }

        let mem_align = self.mem_align();
        let req_align = self.req_align();

        if is_aligned(&bufv, offset, mem_align, req_align) {
            let _sw_guard = self.weak_write_blocker(offset..(offset + bufv.len())).await;

            // Safe: Alignment checked, and weak write blocker set up
            return unsafe { self.pure_writev(bufv, offset) }.await;
        }

        trace!(
            "Unaligned write: {offset:#x} + {} (size: {:#x})",
            bufv.len(),
            self.size().unwrap()
        );

        let req_align_mask = req_align - 1;
        // Length must be aligned to both memory and request alignments
        let len_align_mask = req_align_mask | (mem_align - 1);
        let len_align = len_align_mask + 1;
        debug_assert!(len_align.is_multiple_of(req_align));

        let unpadded_end = offset + bufv.len();
        let padded_offset = offset & !(req_align_mask as u64);
        // This will over-align at the end of file (aligning to exactly the end of file would be
        // sufficient), but it is easier this way.  Small TODO, as this will indeed increase the
        // file length (which the over-alignment in the unaligned `readv()` path above does not).
        let padded_end = (unpadded_end + req_align_mask as u64) & !(req_align_mask as u64);
        // Now also align to memory alignment
        let padded_len =
            (padded_end - padded_offset + len_align_mask as u64) & !(len_align_mask as u64);
        let padded_end = padded_offset + padded_len;

        let padded_len: usize = (padded_end - padded_offset)
            .try_into()
            .map_err(|e| io::Error::other(format!("Cannot realign write: {e}")))?;

        trace!("Padded write: {padded_offset:#x} + {padded_len}");

        let mut bounce_buf = IoBuffer::new(padded_len, mem_align)?;
        assert!(padded_len >= len_align && padded_len & len_align_mask == 0);

        // For the strong blocker, just the RMW regions (head and tail) would be enough.  However,
        // we don’t expect any concurrent writes to the non-RMW (pure write) regions (it is
        // unlikely that the guest would write to the same area twice concurrently), so we don’t
        // need to optimize for it.  On the other hand, writes to the RMW regions are likely
        // (adjacent writes), so those will be blocked either way.
        // Instating fewer blockers makes them less expensive to check, though.
        let _sw_guard = self.strong_write_blocker(padded_offset..padded_end).await;

        let in_buf_ofs = (offset - padded_offset) as usize;
        // Must fit in `usize` because `padded_len: usize`
        let in_buf_end = (unpadded_end - padded_offset) as usize;

        // RMW part 1: Read

        let head_len = in_buf_ofs;
        let aligned_head_len = (head_len + len_align_mask) & !len_align_mask;

        let tail_len = padded_len - in_buf_end;
        let aligned_tail_len = (tail_len + len_align_mask) & !len_align_mask;

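        // Only the head (everything before `in_buf_ofs`) and the tail (everything from
        // `in_buf_end` on) need to be read back from storage; the middle is overwritten by `bufv`
        // anyway.  For example (illustrative values): with a len-alignment of 4096,
        // `in_buf_ofs == 0x34`, `in_buf_end == 0x1134`, and `padded_len == 0x2000`, the head read
        // covers `0..0x1000` and the tail read covers `0x1000..0x2000`, which happens to be the
        // whole bounce buffer, so a single read suffices.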
        if aligned_head_len + aligned_tail_len == padded_len {
            // Must read the whole bounce buffer
            // Safe: Alignment enforced
            unsafe { self.pure_readv(bounce_buf.as_mut().into(), padded_offset) }.await?;
        } else {
            if aligned_head_len > 0 {
                let head_bufv = bounce_buf.as_mut_range(0..aligned_head_len).into();
                // Safe: Alignment enforced
                unsafe { self.pure_readv(head_bufv, padded_offset) }.await?;
            }
            if aligned_tail_len > 0 {
                let tail_start = padded_len - aligned_tail_len;
                let tail_bufv = bounce_buf.as_mut_range(tail_start..padded_len).into();
                // Safe: Alignment enforced
                unsafe { self.pure_readv(tail_bufv, padded_offset + tail_start as u64) }.await?;
            }
        }

        // RMW part 2: Modify
        bufv.copy_into_slice(bounce_buf.as_mut_range(in_buf_ofs..in_buf_end).into_slice());

        // RMW part 3: Write
        // Safe: Alignment enforced, and strong write blocker set up
        unsafe { self.pure_writev(bounce_buf.as_ref().into(), padded_offset) }.await
    }

    async fn read(&self, buf: impl Into<IoVectorMut<'_>>, offset: u64) -> io::Result<()> {
        self.readv(buf.into(), offset).await
    }

    async fn write(&self, buf: impl Into<IoVector<'_>>, offset: u64) -> io::Result<()> {
        self.writev(buf.into(), offset).await
    }

    async fn write_zeroes(&self, offset: u64, length: u64) -> io::Result<()> {
        write_efficient_zeroes(self, offset, length, false).await
    }

    async fn write_allocated_zeroes(&self, offset: u64, length: u64) -> io::Result<()> {
        write_efficient_zeroes(self, offset, length, true).await
    }

    async fn discard(&self, offset: u64, length: u64) -> io::Result<()> {
        let discard_align = self.discard_align();
        debug_assert!(discard_align.is_power_of_two());
        let align_mask = discard_align as u64 - 1;

        let unaligned_end = offset
            .checked_add(length)
            .ok_or_else(|| io::Error::other("Discard wrap-around"))?;
        let aligned_offset = (offset + align_mask) & !align_mask;
        let aligned_end = unaligned_end & !align_mask;

        if aligned_end > aligned_offset {
            let _sw_guard = self.weak_write_blocker(aligned_offset..aligned_end).await;
            // Safe: Alignment checked, and weak write blocker set up
            unsafe { self.pure_discard(aligned_offset, aligned_end - aligned_offset) }.await?;
        }

        // Nothing to do for the unaligned part; discarding is always just advisory.

        Ok(())
    }

    async fn weak_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
        self.get_storage_helper().weak_write_blocker(range).await
    }

    async fn strong_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
        self.get_storage_helper().strong_write_blocker(range).await
    }
}

/// Check whether the given request is aligned.
fn is_aligned<V: IoVectorTrait>(bufv: &V, offset: u64, mem_align: usize, req_align: usize) -> bool {
    debug_assert!(mem_align.is_power_of_two() && req_align.is_power_of_two());

    let req_align_mask = req_align as u64 - 1;

    offset & req_align_mask == 0
        && bufv.len() & req_align_mask == 0
        && bufv.is_aligned(mem_align, req_align)
}

/// Write zero data to the given area.
///
/// In contrast to `write_zeroes()` functions, this one will actually write zero data, fully
/// allocated.
pub(crate) async fn write_full_zeroes<S: StorageExt>(
    storage: S,
    mut offset: u64,
    mut length: u64,
) -> io::Result<()> {
    let buflen = cmp::min(length, 1048576) as usize;
    let mut buf = IoBuffer::new(buflen, storage.mem_align())?;
    buf.as_mut().into_slice().fill(0);

    let req_align = storage.req_align();
    let req_align_mask = (req_align - 1) as u64;

    while length > 0 {
        let mut chunk_length = cmp::min(length, 1048576) as usize;
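        // If `offset` is not request-aligned, clamp this chunk so that it ends at the next
        // request alignment boundary; the following chunks then start request-aligned, sparing
        // `write()` its bounce-buffer RMW path for most of the range.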
        if offset & req_align_mask != 0 {
            chunk_length = cmp::min(chunk_length, req_align - (offset & req_align_mask) as usize);
        }
        storage
            .write(buf.as_ref_range(0..chunk_length), offset)
            .await?;
        offset += chunk_length as u64;
        length -= chunk_length as u64;
    }

    Ok(())
}

/// Write zeroes efficiently to the given area.
///
/// This implements `write_zeroes()` and `write_allocated_zeroes()`.
///
/// If `allocate` is `true`, use [`Storage::pure_write_allocated_zeroes()`]; else, use
/// [`Storage::pure_write_zeroes()`].
///
/// If the `pure_*` call fails with [`io::ErrorKind::Unsupported`], fall back to
/// [`write_full_zeroes()`].
pub(crate) async fn write_efficient_zeroes<S: StorageExt>(
    storage: S,
    offset: u64,
    length: u64,
    allocate: bool,
) -> io::Result<()> {
    let zero_align = storage.zero_align();
    debug_assert!(zero_align.is_power_of_two());
    let align_mask = zero_align as u64 - 1;

    let unaligned_end = offset
        .checked_add(length)
        .ok_or_else(|| io::Error::other("Zero-write wrap-around"))?;
    let aligned_offset = (offset + align_mask) & !align_mask;
    let aligned_end = unaligned_end & !align_mask;

    if aligned_end <= aligned_offset {
        // The range lies within a single `zero_align` block, so there is no aligned middle part
        // we could zero efficiently (and, if `aligned_end < aligned_offset`, the head/tail
        // handling below would write outside the requested range).  Fall back to writing
        // explicit zeroes for the whole (short) range.
        return write_full_zeroes(storage, offset, length).await;
    }

    // Zero the aligned middle part.
    {
        let result = {
            let _sw_guard = storage
                .weak_write_blocker(aligned_offset..aligned_end)
                .await;
            // Safe: Alignment checked, and weak write blocker set up
            if allocate {
                unsafe {
                    storage
                        .pure_write_allocated_zeroes(aligned_offset, aligned_end - aligned_offset)
                }
                .await
            } else {
                unsafe { storage.pure_write_zeroes(aligned_offset, aligned_end - aligned_offset) }
                    .await
            }
        };
        if let Err(err) = result {
            return if err.kind() == io::ErrorKind::Unsupported {
                write_full_zeroes(storage, offset, length).await
            } else {
                Err(err)
            };
        }
    }

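    // The unaligned head (before `aligned_offset`) and tail (after `aligned_end`), if any, cannot
    // be zeroed via the aligned fast path, so write explicit zero data there through `write()`.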
    let zero_buf = if aligned_offset > offset || aligned_end < unaligned_end {
        let mut buf = IoBuffer::new(
            cmp::max(aligned_offset - offset, unaligned_end - aligned_end) as usize,
            storage.mem_align(),
        )?;
        buf.as_mut().into_slice().fill(0);
        Some(buf)
    } else {
        None
    };

    if aligned_offset > offset {
        let buf = zero_buf
            .as_ref()
            .unwrap()
            .as_ref_range(0..((aligned_offset - offset) as usize));
        storage.write(buf, offset).await?;
    }
    if aligned_end < unaligned_end {
        let buf = zero_buf
            .as_ref()
            .unwrap()
            .as_ref_range(0..((unaligned_end - aligned_end) as usize));
        storage.write(buf, aligned_end).await?;
    }

    Ok(())
}