Skip to main content

ostd/sync/
wait.rs

1// SPDX-License-Identifier: MPL-2.0
2use vstd::atomic_ghost::*;
3use vstd::prelude::*;
4
5use alloc::{collections::VecDeque, sync::Arc};
6use core::sync::atomic::{AtomicBool, Ordering};
7
8use super::{LocalIrqDisabled, SpinLock};
9use crate::task::{scheduler, Task};
10
11// # Explanation on the memory orders
12//
13// ```
14// [CPU 1 (the waker)]     [CPU 2 (the waiter)]
15// cond = true;
16// wake_up();
17//                         wait();
18//                         if cond { /* .. */ }
19// ```
20//
21// As soon as the waiter is woken up by the waker, it must see the true condition. This is
22// trivially satisfied if `wake_up()` and `wait()` synchronize with a lock. But if they synchronize
23// with an atomic variable, `wake_up()` must access the variable with `Ordering::Release` and
24// `wait()` must access the variable with `Ordering::Acquire`.
25//
26// Examples of `wake_up()`:
27//  - `WaitQueue::wake_one()`
28//  - `WaitQueue::wake_all()`
29//  - `Waker::wake_up()`
30//
31// Examples of `wait()`:
32//  - `WaitQueue::wait_until()`
33//  - `Waiter::wait()`
34//  - `Waiter::drop()`
35//
36// Note that dropping a waiter must be treated as a `wait()` with zero timeout, because we need to
37// make sure that the wake event isn't lost in this case.
38
39verus! {
40
41pub tracked struct WaitQueueGhost {
42    pub queued_wakers: Seq<int>,
43}
44
45struct_with_invariants! {
46
47/// A wait queue.
48///
49/// One may wait on a wait queue to put its executing thread to sleep.
50/// Multiple threads may be the waiters of a wait queue.
51/// Other threads may invoke the `wake`-family methods of a wait queue to
52/// wake up one or many waiting threads.
53pub struct WaitQueue {
54    // A copy of `wakers.len()`, used for the lock-free fast path in `wake_one` and `wake_all`.
55    num_wakers: AtomicU32<_, WaitQueueGhost, _>,
56    wakers: SpinLock<VecDeque<Arc<Waker>>, LocalIrqDisabled>,
57}
58
59closed spec fn wf(self) -> bool {
60    invariant on num_wakers is (v: u32, g: WaitQueueGhost) {
61        &&& g.queued_wakers.len() == v as int
62    }
63}
64}
65
66impl WaitQueue {
67    #[verifier::type_invariant]
68    pub closed spec fn type_inv(self) -> bool {
69        self.wf()
70    }
71}
72
73impl WaitQueue {
74    /// Creates a new, empty wait queue.
75    #[verifier::external_body]
76    pub const fn new() -> Self {
77        WaitQueue {
78            num_wakers: AtomicU32::new(
79                Ghost(()),
80                0,
81                Tracked(WaitQueueGhost { queued_wakers: seq![] }),
82            ),
83            wakers: SpinLock::new(VecDeque::new()),
84        }
85    }
86
87    /// Waits until some condition is met.
88    ///
89    /// This method takes a closure that tests a user-given condition.
90    /// The method only returns if the condition returns `Some(_)`.
91    /// A waker thread should first make the condition `Some(_)`, then invoke the
92    /// `wake`-family method. This ordering is important to ensure that waiter
93    /// threads do not lose any wakeup notifications.
94    ///
95    /// By taking a condition closure, this wait-wakeup mechanism becomes
96    /// more efficient and robust.
97    #[track_caller]
98    #[verus_spec(ret =>
99        requires
100            cond.requires(()),
101        ensures
102            cond.ensures((), Some(ret)),
103    )]
104    #[verifier::exec_allows_no_decreases_clause]
105    pub fn wait_until<F, R>(&self, mut cond: F) -> R where F: FnMut() -> Option<R> {
106        if let Some(res) = cond() {
107            return res;
108        }
109        let (waiter, _) = Waiter::new_pair();
110        #[verus_spec(invariant
111            cond.requires(()),
112        )]
113        loop {
114            self.enqueue(waiter.waker());
115            if let Some(res) = cond() {
116                assert(cond.ensures((), Some(res)));
117                proof! { admit(); } // FIXME: https://github.com/verus-lang/verus/issues/2295
118                return res;
119            }
120            waiter.wait();
121        }
122    }
123
124    /// Wakes up one waiting thread, if there is one at the point of time when this method is
125    /// called, returning whether such a thread was woken up.
126    #[verifier::external_body]
127    pub fn wake_one(&self) -> (r: bool) {
128        // Fast path
129        if self.is_empty() {
130            return false;
131        }
132        loop {
133            let mut wakers = self.wakers.lock();
134            let Some(waker) = wakers.pop_front() else {
135                return false;
136            };
137            atomic_with_ghost! {
138                self.num_wakers => fetch_sub(1);
139                update prev -> next;
140                ghost g => {
141                    g = WaitQueueGhost { queued_wakers: g.queued_wakers.drop_first() };
142                }
143            };
144            // Avoid holding lock when calling `wake_up`
145            drop(wakers);
146
147            if waker.wake_up() {
148                return true;
149            }
150        }
151    }
152
153    /// Wakes up all waiting threads, returning the number of threads that were woken up.
154    #[verifier::external_body]
155    pub fn wake_all(&self) -> (r: usize) {
156        // Fast path
157        if self.is_empty() {
158            return 0;
159        }
160        let mut num_woken = 0;
161
162        loop {
163            let mut wakers = self.wakers.lock();
164            let Some(waker) = wakers.pop_front() else {
165                break ;
166            };
167            atomic_with_ghost! {
168                self.num_wakers => fetch_sub(1);
169                update prev -> next;
170                ghost g => {
171                    g = WaitQueueGhost { queued_wakers: g.queued_wakers.drop_first() };
172                }
173            };
174            // Avoid holding lock when calling `wake_up`
175            drop(wakers);
176
177            if waker.wake_up() {
178                num_woken += 1;
179            }
180        }
181
182        num_woken
183    }
184
185    #[verifier::external_body]
186    fn is_empty(&self) -> bool {
187        self.num_wakers.load() == 0
188    }
189
190    /// Enqueues the input [`Waker`] to the wait queue.
191    #[doc(hidden)]
192    #[verifier::external_body]
193    pub fn enqueue(&self, waker: Arc<Waker>) {
194        let mut wakers = self.wakers.lock();
195        wakers.push_back(waker);
196        atomic_with_ghost! {
197            self.num_wakers => fetch_add(1);
198            update prev -> next;
199            ghost g => {
200                g = WaitQueueGhost { queued_wakers: g.queued_wakers.push(waker.id()) };
201            }
202        };
203    }
204}
205
206impl Default for WaitQueue {
207    #[verifier::external_body]
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213/// A waiter that can put the current thread to sleep until it is woken up by the associated
214/// [`Waker`].
215///
216/// By definition, a waiter belongs to the current thread, so it cannot be sent to another thread
217/// and its reference cannot be shared between threads.
218pub struct Waiter {
219    waker: Arc<Waker>,
220}
221
222impl !Send for Waiter {
223
224}
225
226impl !Sync for Waiter {
227
228}
229
230impl Waiter {
231    /// Checks if the input waker is the associated waker of the current waiter.
232    pub closed spec fn rel_waker(self, waker: Arc<Waker>) -> bool {
233        self.waker == waker
234    }
235
236    /// Abstract identity of the paired waker.
237    pub closed spec fn waker_id(self) -> int {
238        self.waker.id()
239    }
240}
241
242/// A waker that can wake up the associated [`Waiter`].
243///
244/// A waker can be created by calling [`Waiter::new_pair`]. This method creates an `Arc<Waker>` that can
245/// be used across different threads.
246pub struct Waker {
247    has_woken: AtomicBool,
248    task: Arc<Task>,
249    v_id: Ghost<int>,
250}
251
252impl Waker {
253    /// Abstract identity used by the queue model.
254    pub closed spec fn id(self) -> int {
255        self.v_id@
256    }
257}
258
259#[verus_verify]
260impl Waiter {
261    /// Creates a waiter and its associated [`Waker`].
262    #[verus_spec(ret =>
263        ensures
264            ret.0.rel_waker(ret.1),
265    )]
266    pub fn new_pair() -> (Self, Arc<Waker>) {
267        proof_decl! {
268            let ghost waker_id: int = arbitrary();
269        }
270        let waker = Arc::new(
271            Waker {
272                has_woken: AtomicBool::new(false),
273                // task: Task::current().unwrap().cloned(),
274                task: Arc::new(Task {  }),
275                v_id: Ghost(waker_id),
276            },
277        );
278        let waiter = Self { waker: waker.clone() };
279        (waiter, waker)
280    }
281
282    /// Waits until the waiter is woken up by calling [`Waker::wake_up`] on the associated
283    /// [`Waker`].
284    ///
285    /// This method returns immediately if the waiter has been woken since the end of the last call
286    /// to this method (or since the waiter was created, if this method has not been called
287    /// before). Otherwise, it puts the current thread to sleep until the waiter is woken up.
288    #[track_caller]
289    pub fn wait(&self) {
290        self.waker.do_wait();
291    }
292
293    /// Waits until some condition is met or the cancel condition becomes true.
294    ///
295    /// This method will return `Ok(_)` if the condition returns `Some(_)`, and will stop waiting
296    /// if the cancel condition returns `Err(_)`. In this situation, this method will return the `Err(_)`
297    /// generated by the cancel condition.
298    #[verus_spec(ret =>
299        requires
300            cond.requires(()),
301            cancel_cond.requires(()),
302        ensures
303            match ret {
304                Ok(res) => cond.ensures((),Some(res)),
305                Err(e) => cancel_cond.ensures((), Err(e)),
306            },
307    )]
308    #[track_caller]
309    #[verifier::exec_allows_no_decreases_clause]
310    pub fn wait_until_or_cancelled<F, R, FCancel, E>(
311        &self,
312        mut cond: F,
313        cancel_cond: FCancel,
314    ) -> core::result::Result<R, E> where
315        F: FnMut() -> Option<R>,
316        FCancel: Fn() -> core::result::Result<(), E>,
317     {
318        let mut cond = cond;
319        #[verus_spec(invariant
320            cond.requires(()),
321            cancel_cond.requires(()),
322        )]
323        loop {
324            if let Some(res) = cond() {
325                assert(cond.ensures((), Some(res)));
326                proof! { admit(); }  // FIXME: https://github.com/verus-lang/verus/issues/2295
327                return Ok(res);
328            };
329            if let Err(e) = cancel_cond() {
330                // Close the waker and check again to avoid missing a wake event.
331                self.waker.close();
332                proof! { admit(); }  // FIXME: https://github.com/verus-lang/verus/issues/2295
333                return cond().ok_or(e);
334            }
335            self.wait();
336        }
337    }
338
339    /// Gets the associated [`Waker`] of the current waiter.
340    #[verus_spec(ret =>
341        ensures
342            self.rel_waker(ret),
343    )]
344    pub fn waker(&self) -> Arc<Waker> {
345        self.waker.clone()
346    }
347
348    /// Returns the task that the associated waker will attempt to wake up.
349    pub fn task(&self) -> &Arc<Task> {
350        &self.waker.task
351    }
352}
353
354impl Drop for Waiter {
355    #[verifier::external_body]
356    fn drop(&mut self)
357        opens_invariants none
358        no_unwind
359    {
360        // When dropping the waiter, we need to close the waker to ensure that if someone wants to
361        // wake up the waiter afterwards, they will perform a no-op.
362        self.waker.close();
363    }
364}
365
366impl Waker {
367    /// Wakes up the associated [`Waiter`].
368    ///
369    /// This method returns `true` if the waiter is woken by this call. It returns `false` if the
370    /// waiter has already been woken by a previous call to the method, or if the waiter has been
371    /// dropped.
372    ///
373    /// Note that if this method returns `true`, it implies that the wake event will be properly
374    /// delivered, _or_ that the waiter will be dropped after being woken. It's up to the caller to
375    /// handle the latter case properly to avoid missing the wake event.
376    #[verifier::external_body]
377    pub fn wake_up(&self) -> bool {
378        if self.has_woken.swap(true, Ordering::Release) {
379            return false;
380        }
381        scheduler::unpark_target(self.task.clone());
382
383        true
384    }
385
386    #[track_caller]
387    #[verifier::external_body]
388    fn do_wait(&self) {
389        while !self.has_woken.swap(false, Ordering::Acquire) {
390            scheduler::park_current(|| self.has_woken.load(Ordering::Acquire));
391        }
392    }
393
394    #[verifier::external_body]
395    fn close(&self) {
396        // This must use `Ordering::Acquire`, although we do not care about the return value. See
397        // the memory order explanation at the top of the file for details.
398        let _ = self.has_woken.swap(true, Ordering::Acquire);
399    }
400}
401
402} // verus!
#[cfg(ktest)]
mod test {
    use super::*;
    use crate::{prelude::*, task::TaskOptions};

    /// Shared driver for the queue tests: a spawned task sets the flag and then runs `wake`,
    /// while the calling task waits on the queue for the flag.
    fn queue_wake<F>(wake: F)
    where
        F: Fn(&WaitQueue) + Sync + Send + 'static,
    {
        let queue = Arc::new(WaitQueue::new());
        let cond = Arc::new(AtomicBool::new(false));

        let task_queue = queue.clone();
        let task_cond = cond.clone();

        TaskOptions::new(move || {
            Task::yield_now();

            task_cond.store(true, Ordering::Relaxed);
            wake(&task_queue);
        })
        .data(())
        .spawn()
        .unwrap();

        queue.wait_until(|| cond.load(Ordering::Relaxed).then_some(()));

        assert!(cond.load(Ordering::Relaxed));
    }

    #[ktest]
    fn queue_wake_one() {
        queue_wake(|q| {
            q.wake_one();
        });
    }

    #[ktest]
    fn queue_wake_all() {
        queue_wake(|q| {
            q.wake_all();
        });
    }

    /// Only the first `wake_up` on a waker reports success.
    #[ktest]
    fn waiter_wake_twice() {
        // `_waiter` must stay alive: dropping it would close the waker.
        let (_waiter, waker) = Waiter::new_pair();

        let first = waker.wake_up();
        assert!(first);
        let second = waker.wake_up();
        assert!(!second);
    }

    /// Waking after the waiter has been dropped is a no-op.
    #[ktest]
    fn waiter_wake_drop() {
        let (waiter, waker) = Waiter::new_pair();

        drop(waiter);
        let woke = waker.wake_up();
        assert!(!woke);
    }

    /// A waiter woken from another task observes the flag stored before the wake-up.
    #[ktest]
    fn waiter_wake_async() {
        let (waiter, waker) = Waiter::new_pair();

        let cond = Arc::new(AtomicBool::new(false));
        let task_cond = cond.clone();

        TaskOptions::new(move || {
            Task::yield_now();

            task_cond.store(true, Ordering::Relaxed);
            assert!(waker.wake_up());
        })
        .data(())
        .spawn()
        .unwrap();

        waiter.wait();

        assert!(cond.load(Ordering::Relaxed));
    }

    /// Two waiters are woken in the opposite order from the one in which they wait.
    #[ktest]
    fn waiter_wake_reorder() {
        let (waiter, waker) = Waiter::new_pair();

        let cond = Arc::new(AtomicBool::new(false));
        let task_cond = cond.clone();

        let (waiter2, waker2) = Waiter::new_pair();

        let cond2 = Arc::new(AtomicBool::new(false));
        let task_cond2 = cond2.clone();

        TaskOptions::new(move || {
            Task::yield_now();

            // The second pair is woken first ...
            task_cond2.store(true, Ordering::Relaxed);
            assert!(waker2.wake_up());

            Task::yield_now();

            // ... and the first pair afterwards.
            task_cond.store(true, Ordering::Relaxed);
            assert!(waker.wake_up());
        })
        .data(())
        .spawn()
        .unwrap();

        waiter.wait();
        assert!(cond.load(Ordering::Relaxed));

        waiter2.wait();
        assert!(cond2.load(Ordering::Relaxed));
    }
}