use super::drivers::RangeBlockedGuard;
use crate::io_buffers::{IoBuffer, IoVector, IoVectorMut, IoVectorTrait};
use crate::Storage;
use std::ops::Range;
use std::{cmp, io};
use tracing::trace;

pub trait StorageExt: Storage {
    /// Read data at `offset` into `bufv`, using a bounce buffer for unaligned requests.
    #[allow(async_fn_in_trait)]
    async fn readv(&self, bufv: IoVectorMut<'_>, offset: u64) -> io::Result<()>;

    /// Write data from `bufv` to `offset`, using a read-modify-write cycle for unaligned requests.
    #[allow(async_fn_in_trait)]
    async fn writev(&self, bufv: IoVector<'_>, offset: u64) -> io::Result<()>;

    /// Read data at `offset` into `buf`.
    #[allow(async_fn_in_trait)]
    async fn read(&self, buf: impl Into<IoVectorMut<'_>>, offset: u64) -> io::Result<()>;

    /// Write data from `buf` to `offset`.
    #[allow(async_fn_in_trait)]
    async fn write(&self, buf: impl Into<IoVector<'_>>, offset: u64) -> io::Result<()>;

    /// Write zeroes to the given range.
    #[allow(async_fn_in_trait)]
    async fn write_zeroes(&self, offset: u64, length: u64) -> io::Result<()>;

    /// Discard the given range, i.e. hint that its contents are no longer needed.
    #[allow(async_fn_in_trait)]
    async fn discard(&self, offset: u64, length: u64) -> io::Result<()>;

    /// Obtain a weak write blocker for `range`, as used for plain writes.
    #[allow(async_fn_in_trait)]
    async fn weak_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_>;

    /// Obtain a strong write blocker for `range`, as used around read-modify-write cycles.
    #[allow(async_fn_in_trait)]
    async fn strong_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_>;
}

impl<S: Storage> StorageExt for S {
    async fn readv(&self, mut bufv: IoVectorMut<'_>, offset: u64) -> io::Result<()> {
        if bufv.is_empty() {
            return Ok(());
        }

        let mem_align = self.mem_align();
        let req_align = self.req_align();

        if is_aligned(&bufv, offset, mem_align, req_align) {
            return unsafe { self.pure_readv(bufv, offset) }.await;
        }

        trace!(
            "Unaligned read: {offset:#x} + {} (size: {:#x})",
            bufv.len(),
            self.size().unwrap()
        );

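        // The request must be aligned to the request alignment (offset and length), and the
        // bounce buffer to the memory alignment: round the offset down and the end up to
        // `req_align`, then round the resulting length up to the combined alignment.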
        let req_align_mask = req_align as u64 - 1;
        let len_align_mask = req_align_mask | (mem_align as u64 - 1);
        debug_assert!((len_align_mask + 1).is_multiple_of(req_align as u64));

        let unpadded_end = offset + bufv.len();
        let padded_offset = offset & !req_align_mask;
        let padded_end = (unpadded_end + req_align_mask) & !req_align_mask;
        let padded_len = (padded_end - padded_offset + len_align_mask) & !len_align_mask;
        let padded_end = padded_offset + padded_len;

        let padded_len: usize = (padded_end - padded_offset)
            .try_into()
            .map_err(|e| io::Error::other(format!("Cannot realign read: {e}")))?;

        trace!("Padded read: {padded_offset:#x} + {padded_len}");

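        // Read the padded range into an aligned bounce buffer, then copy the requested
        // sub-range into the caller's vector.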
        let mut bounce_buf = IoBuffer::new(padded_len, mem_align)?;

        unsafe { self.pure_readv(bounce_buf.as_mut().into(), padded_offset) }.await?;

        let in_buf_ofs = (offset - padded_offset) as usize;
        let in_buf_end = (unpadded_end - padded_offset) as usize;

        bufv.copy_from_slice(bounce_buf.as_ref_range(in_buf_ofs..in_buf_end).into_slice());

        Ok(())
    }

    async fn writev(&self, bufv: IoVector<'_>, offset: u64) -> io::Result<()> {
        if bufv.is_empty() {
            return Ok(());
        }

        let mem_align = self.mem_align();
        let req_align = self.req_align();

        if is_aligned(&bufv, offset, mem_align, req_align) {
            let _sw_guard = self.weak_write_blocker(offset..(offset + bufv.len())).await;

            return unsafe { self.pure_writev(bufv, offset) }.await;
        }

        trace!(
            "Unaligned write: {offset:#x} + {} (size: {:#x})",
            bufv.len(),
            self.size().unwrap()
        );

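        // Unaligned writes are performed as a read-modify-write cycle through a bounce buffer:
        // pad the range to the required alignment, read the padded head and tail, overlay the
        // caller's data, and write the whole padded range back.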
        let req_align_mask = req_align - 1;
        let len_align_mask = req_align_mask | (mem_align - 1);
        let len_align = len_align_mask + 1;
        debug_assert!(len_align.is_multiple_of(req_align));

        let unpadded_end = offset + bufv.len();
        let padded_offset = offset & !(req_align_mask as u64);
        let padded_end = (unpadded_end + req_align_mask as u64) & !(req_align_mask as u64);
        let padded_len =
            (padded_end - padded_offset + len_align_mask as u64) & !(len_align_mask as u64);
        let padded_end = padded_offset + padded_len;

        let padded_len: usize = (padded_end - padded_offset)
            .try_into()
            .map_err(|e| io::Error::other(format!("Cannot realign write: {e}")))?;

        trace!("Padded write: {padded_offset:#x} + {padded_len}");

        let mut bounce_buf = IoBuffer::new(padded_len, mem_align)?;
        assert!(padded_len >= len_align && padded_len & len_align_mask == 0);

        let _sw_guard = self.strong_write_blocker(padded_offset..padded_end).await;

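        // With the padded range blocked against all concurrent writes, fill the parts of the
        // bounce buffer that the caller's data does not cover with the current on-disk data.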
        let in_buf_ofs = (offset - padded_offset) as usize;
        let in_buf_end = (unpadded_end - padded_offset) as usize;

        let head_len = in_buf_ofs;
        let aligned_head_len = (head_len + len_align_mask) & !len_align_mask;

        let tail_len = padded_len - in_buf_end;
        let aligned_tail_len = (tail_len + len_align_mask) & !len_align_mask;

        if aligned_head_len + aligned_tail_len == padded_len {
            // Head and tail cover the whole bounce buffer; read it in one request
            unsafe { self.pure_readv(bounce_buf.as_mut().into(), padded_offset) }.await?;
        } else {
            // Read only the (aligned) head and tail parts that the caller's data will not
            // overwrite
            if aligned_head_len > 0 {
                let head_bufv = bounce_buf.as_mut_range(0..aligned_head_len).into();
                unsafe { self.pure_readv(head_bufv, padded_offset) }.await?;
            }
            if aligned_tail_len > 0 {
                let tail_start = padded_len - aligned_tail_len;
                let tail_bufv = bounce_buf.as_mut_range(tail_start..padded_len).into();
                unsafe { self.pure_readv(tail_bufv, padded_offset + tail_start as u64) }.await?;
            }
        }

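        // Overlay the caller's data, then write the whole padded buffer back in one request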
        bufv.copy_into_slice(bounce_buf.as_mut_range(in_buf_ofs..in_buf_end).into_slice());

        unsafe { self.pure_writev(bounce_buf.as_ref().into(), padded_offset) }.await
    }

    async fn read(&self, buf: impl Into<IoVectorMut<'_>>, offset: u64) -> io::Result<()> {
        self.readv(buf.into(), offset).await
    }

    async fn write(&self, buf: impl Into<IoVector<'_>>, offset: u64) -> io::Result<()> {
        self.writev(buf.into(), offset).await
    }

    async fn write_zeroes(&self, offset: u64, length: u64) -> io::Result<()> {
        let zero_align = self.zero_align();
        debug_assert!(zero_align.is_power_of_two());
        let align_mask = zero_align as u64 - 1;

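        // Zero the zero-aligned middle of the range via `pure_write_zeroes`, and write explicit
        // zeroes for any unaligned head and tail.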
        let unaligned_end = offset
            .checked_add(length)
            .ok_or_else(|| io::Error::other("Zero-write wrap-around"))?;
        // Clamp so the unaligned head and tail never extend beyond the requested range
        // (relevant when the whole request lies within a single zero-alignment block)
        let aligned_offset = cmp::min((offset + align_mask) & !align_mask, unaligned_end);
        let aligned_end = cmp::max(unaligned_end & !align_mask, aligned_offset);

        if aligned_end > aligned_offset {
            let _sw_guard = self.weak_write_blocker(aligned_offset..aligned_end).await;
            unsafe { self.pure_write_zeroes(aligned_offset, aligned_end - aligned_offset) }
                .await?;
        }

        let zero_buf = if aligned_offset > offset || aligned_end < unaligned_end {
            let mut buf = IoBuffer::new(
                cmp::max(aligned_offset - offset, unaligned_end - aligned_end) as usize,
                self.mem_align(),
            )?;
            buf.as_mut().into_slice().fill(0);
            Some(buf)
        } else {
            None
        };

        if aligned_offset > offset {
            let buf = zero_buf
                .as_ref()
                .unwrap()
                .as_ref_range(0..((aligned_offset - offset) as usize));
            self.write(buf, offset).await?;
        }
        if aligned_end < unaligned_end {
            let buf = zero_buf
                .as_ref()
                .unwrap()
                .as_ref_range(0..((unaligned_end - aligned_end) as usize));
            self.write(buf, aligned_end).await?;
        }

        Ok(())
    }

    async fn discard(&self, offset: u64, length: u64) -> io::Result<()> {
        let discard_align = self.discard_align();
        debug_assert!(discard_align.is_power_of_two());
        let align_mask = discard_align as u64 - 1;

        let unaligned_end = offset
            .checked_add(length)
            .ok_or_else(|| io::Error::other("Discard wrap-around"))?;
        let aligned_offset = (offset + align_mask) & !align_mask;
        let aligned_end = unaligned_end & !align_mask;

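        // Discarding is purely advisory, so only the discard-aligned middle of the range is
        // discarded; any unaligned head and tail are simply left as they are.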
        if aligned_end > aligned_offset {
            let _sw_guard = self.weak_write_blocker(aligned_offset..aligned_end).await;
            unsafe { self.pure_discard(aligned_offset, aligned_end - aligned_offset) }.await?;
        }

        Ok(())
    }

    async fn weak_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
        self.get_storage_helper().weak_write_blocker(range).await
    }

    async fn strong_write_blocker(&self, range: Range<u64>) -> RangeBlockedGuard<'_> {
        self.get_storage_helper().strong_write_blocker(range).await
    }
}

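/// Check whether the given request (`bufv` + `offset`) fulfills the memory and request
/// alignment requirements.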
fn is_aligned<V: IoVectorTrait>(bufv: &V, offset: u64, mem_align: usize, req_align: usize) -> bool {
    debug_assert!(mem_align.is_power_of_two() && req_align.is_power_of_two());

    let req_align_mask = req_align as u64 - 1;

    if offset & req_align_mask != 0 {
        false
    } else if bufv.len() & req_align_mask == 0 {
        bufv.is_aligned(mem_align, req_align)
    } else {
        false
    }
}

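/// Write zeroes by writing an explicitly zeroed buffer, without using `write_zeroes`.
///
/// Zeroes are written in chunks of at most 1 MiB; an unaligned starting chunk is shortened so
/// that subsequent writes begin on a request-alignment boundary.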
pub(crate) async fn write_full_zeroes<S: StorageExt>(
    storage: S,
    mut offset: u64,
    mut length: u64,
) -> io::Result<()> {
    let buflen = cmp::min(length, 1048576) as usize;
    let mut buf = IoBuffer::new(buflen, storage.mem_align())?;
    buf.as_mut().into_slice().fill(0);

    let req_align = storage.req_align();
    let req_align_mask = (req_align - 1) as u64;

    while length > 0 {
        let mut chunk_length = cmp::min(length, 1048576) as usize;
        if offset & req_align_mask != 0 {
            // Limit the first chunk so that subsequent writes are aligned to `req_align`
            chunk_length = cmp::min(chunk_length, req_align - (offset & req_align_mask) as usize);
        }
        storage
            .write(buf.as_ref_range(0..chunk_length), offset)
            .await?;
        offset += chunk_length as u64;
        length -= chunk_length as u64;
    }

    Ok(())
}