ostd/sync/
wait.rs

1// SPDX-License-Identifier: MPL-2.0
2
3use alloc::{collections::VecDeque, sync::Arc};
4use core::sync::atomic::{AtomicBool, AtomicU32, Ordering};
5
6use super::{LocalIrqDisabled, SpinLock};
7use crate::task::{Task, scheduler};
8
// # Explanation on the memory orders
//
// ```
// [CPU 1 (the waker)]     [CPU 2 (the waiter)]
// cond = true;
// wake_up();
//                         wait();
//                         if cond { /* .. */ }
// ```
//
// As soon as the waiter is woken up by the waker, it must observe the true condition. This is
// trivially satisfied if `wake_up()` and `wait()` synchronize through a lock. But if they
// synchronize through an atomic variable, `wake_up()` must access the variable with
// `Ordering::Release` and `wait()` must access the variable with `Ordering::Acquire`. When the
// waker side only needs to *read* the variable (as in the fast-path check `WaitQueue::is_empty()`),
// it still performs a read-modify-write, because a plain atomic load cannot use `Ordering::Release`.
//
// Examples of `wake_up()`:
//  - `WaitQueue::wake_one()`
//  - `WaitQueue::wake_all()`
//  - `Waker::wake_up()`
//
// Examples of `wait()`:
//  - `WaitQueue::wait_until()`
//  - `Waiter::wait()`
//  - `Waiter::drop()`
//
// Note that dropping a waiter must be treated as a `wait()` with zero timeout, because we need to
// make sure that the wake event isn't lost in this case.
36
37/// A wait queue.
38///
39/// One may wait on a wait queue to put its executing thread to sleep.
40/// Multiple threads may be the waiters of a wait queue.
41/// Other threads may invoke the `wake`-family methods of a wait queue to
42/// wake up one or many waiting threads.
43pub struct WaitQueue {
44    // A copy of `wakers.len()`, used for the lock-free fast path in `wake_one` and `wake_all`.
45    num_wakers: AtomicU32,
46    wakers: SpinLock<VecDeque<Arc<Waker>>, LocalIrqDisabled>,
47}
48
49impl WaitQueue {
50    /// Creates a new, empty wait queue.
51    pub const fn new() -> Self {
52        WaitQueue {
53            num_wakers: AtomicU32::new(0),
54            wakers: SpinLock::new(VecDeque::new()),
55        }
56    }
57
58    /// Waits until some condition is met.
59    ///
60    /// This method takes a closure that tests a user-given condition.
61    /// The method only returns if the condition returns `Some(_)`.
62    /// A waker thread should first make the condition `Some(_)`, then invoke the
63    /// `wake`-family method. This ordering is important to ensure that waiter
64    /// threads do not lose any wakeup notifications.
65    ///
66    /// By taking a condition closure, this wait-wakeup mechanism becomes
67    /// more efficient and robust.
68    #[track_caller]
69    pub fn wait_until<F, R>(&self, mut cond: F) -> R
70    where
71        F: FnMut() -> Option<R>,
72    {
73        if let Some(res) = cond() {
74            return res;
75        }
76
77        let (waiter, _) = Waiter::new_pair();
78        let cond = || {
79            self.enqueue(waiter.waker());
80            cond()
81        };
82        waiter
83            .wait_until_or_cancelled(cond, || Ok::<(), ()>(()))
84            .unwrap()
85    }
86
87    /// Wakes up one waiting thread, if there is one at the point of time when this method is
88    /// called, returning whether such a thread was woken up.
89    pub fn wake_one(&self) -> bool {
90        // Fast path
91        if self.is_empty() {
92            return false;
93        }
94
95        loop {
96            let mut wakers = self.wakers.lock();
97            let Some(waker) = wakers.pop_front() else {
98                return false;
99            };
100            self.num_wakers.fetch_sub(1, Ordering::Release);
101            // Avoid holding lock when calling `wake_up`
102            drop(wakers);
103
104            if waker.wake_up() {
105                return true;
106            }
107        }
108    }
109
110    /// Wakes up all waiting threads, returning the number of threads that were woken up.
111    pub fn wake_all(&self) -> usize {
112        // Fast path
113        if self.is_empty() {
114            return 0;
115        }
116
117        let mut num_woken = 0;
118
119        loop {
120            let mut wakers = self.wakers.lock();
121            let Some(waker) = wakers.pop_front() else {
122                break;
123            };
124            self.num_wakers.fetch_sub(1, Ordering::Release);
125            // Avoid holding lock when calling `wake_up`
126            drop(wakers);
127
128            if waker.wake_up() {
129                num_woken += 1;
130            }
131        }
132
133        num_woken
134    }
135
136    fn is_empty(&self) -> bool {
137        // On x86-64, this generates `mfence; mov`, which is exactly the right way to implement
138        // atomic loading with `Ordering::Release`. It performs much better than naively
139        // translating `fetch_add(0)` to `lock; xadd`.
140        self.num_wakers.fetch_add(0, Ordering::Release) == 0
141    }
142
143    /// Enqueues the input [`Waker`] to the wait queue.
144    #[doc(hidden)]
145    pub fn enqueue(&self, waker: Arc<Waker>) {
146        let mut wakers = self.wakers.lock();
147        wakers.push_back(waker);
148        self.num_wakers.fetch_add(1, Ordering::Acquire);
149    }
150}
151
152impl Default for WaitQueue {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158/// A waiter that can put the current thread to sleep until it is woken up by the associated
159/// [`Waker`].
160///
161/// By definition, a waiter belongs to the current thread, so it cannot be sent to another thread
162/// and its reference cannot be shared between threads.
163pub struct Waiter {
164    waker: Arc<Waker>,
165}
166
167impl !Send for Waiter {}
168impl !Sync for Waiter {}
169
170/// A waker that can wake up the associated [`Waiter`].
171///
172/// A waker can be created by calling [`Waiter::new_pair`]. This method creates an `Arc<Waker>` that can
173/// be used across different threads.
174pub struct Waker {
175    has_woken: AtomicBool,
176    task: Arc<Task>,
177}
178
179impl Waiter {
180    /// Creates a waiter and its associated [`Waker`].
181    pub fn new_pair() -> (Self, Arc<Waker>) {
182        let waker = Arc::new(Waker {
183            has_woken: AtomicBool::new(false),
184            task: Task::current().unwrap().cloned(),
185        });
186        let waiter = Self {
187            waker: waker.clone(),
188        };
189        (waiter, waker)
190    }
191
192    /// Waits until the waiter is woken up by calling [`Waker::wake_up`] on the associated
193    /// [`Waker`].
194    ///
195    /// This method returns immediately if the waiter has been woken since the end of the last call
196    /// to this method (or since the waiter was created, if this method has not been called
197    /// before). Otherwise, it puts the current thread to sleep until the waiter is woken up.
198    #[track_caller]
199    pub fn wait(&self) {
200        self.waker.do_wait();
201    }
202
203    /// Waits until some condition is met or the cancel condition becomes true.
204    ///
205    /// This method will return `Ok(_)` if the condition returns `Some(_)`, and will stop waiting
206    /// if the cancel condition returns `Err(_)`. In this situation, this method will return the `Err(_)`
207    /// generated by the cancel condition.
208    #[track_caller]
209    pub fn wait_until_or_cancelled<F, R, FCancel, E>(
210        &self,
211        mut cond: F,
212        cancel_cond: FCancel,
213    ) -> core::result::Result<R, E>
214    where
215        F: FnMut() -> Option<R>,
216        FCancel: Fn() -> core::result::Result<(), E>,
217    {
218        loop {
219            if let Some(res) = cond() {
220                return Ok(res);
221            };
222
223            if let Err(e) = cancel_cond() {
224                // Close the waker and check again to avoid missing a wake event.
225                self.waker.close();
226                return cond().ok_or(e);
227            }
228
229            self.wait();
230        }
231    }
232
233    /// Gets the associated [`Waker`] of the current waiter.
234    pub fn waker(&self) -> Arc<Waker> {
235        self.waker.clone()
236    }
237
238    /// Returns the task that the associated waker will attempt to wake up.
239    pub fn task(&self) -> &Arc<Task> {
240        &self.waker.task
241    }
242}
243
244impl Drop for Waiter {
245    fn drop(&mut self) {
246        // When dropping the waiter, we need to close the waker to ensure that if someone wants to
247        // wake up the waiter afterwards, they will perform a no-op.
248        self.waker.close();
249    }
250}
251
252impl Waker {
253    /// Wakes up the associated [`Waiter`].
254    ///
255    /// This method returns `true` if the waiter is woken by this call. It returns `false` if the
256    /// waiter has already been woken by a previous call to the method, or if the waiter has been
257    /// dropped.
258    ///
259    /// Note that if this method returns `true`, it implies that the wake event will be properly
260    /// delivered, _or_ that the waiter will be dropped after being woken. It's up to the caller to
261    /// handle the latter case properly to avoid missing the wake event.
262    pub fn wake_up(&self) -> bool {
263        if self.has_woken.swap(true, Ordering::Release) {
264            return false;
265        }
266        scheduler::unpark_target(self.task.clone());
267
268        true
269    }
270
271    #[track_caller]
272    fn do_wait(&self) {
273        while !self.has_woken.swap(false, Ordering::Acquire) {
274            scheduler::park_current(|| self.has_woken.load(Ordering::Acquire));
275        }
276    }
277
278    fn close(&self) {
279        // This must use `Ordering::Acquire`, although we do not care about the return value. See
280        // the memory order explanation at the top of the file for details.
281        let _ = self.has_woken.swap(true, Ordering::Acquire);
282    }
283}
284
#[cfg(ktest)]
mod test {
    use super::*;
    use crate::{prelude::*, task::TaskOptions};

    /// Spawns a task that sets a flag and then calls `wake` on a shared queue. Meanwhile, the
    /// current task blocks in `WaitQueue::wait_until` until it observes the flag.
    fn queue_wake<F>(wake: F)
    where
        F: Fn(&WaitQueue) + Sync + Send + 'static,
    {
        let wait_queue = Arc::new(WaitQueue::new());
        let flag = Arc::new(AtomicBool::new(false));

        let task_queue = Arc::clone(&wait_queue);
        let task_flag = Arc::clone(&flag);
        TaskOptions::new(move || {
            Task::yield_now();

            task_flag.store(true, Ordering::Relaxed);
            wake(&task_queue);
        })
        .data(())
        .spawn()
        .unwrap();

        wait_queue.wait_until(|| flag.load(Ordering::Relaxed).then_some(()));

        assert!(flag.load(Ordering::Relaxed));
    }

    #[ktest]
    fn queue_wake_one() {
        queue_wake(|wq| {
            wq.wake_one();
        });
    }

    #[ktest]
    fn queue_wake_all() {
        queue_wake(|wq| {
            wq.wake_all();
        });
    }

    #[ktest]
    fn waiter_wake_twice() {
        let (_waiter, waker) = Waiter::new_pair();

        // Only the first wake-up succeeds; the second finds the waker already woken.
        assert!(waker.wake_up());
        assert!(!waker.wake_up());
    }

    #[ktest]
    fn waiter_wake_drop() {
        let (waiter, waker) = Waiter::new_pair();

        // Dropping the waiter closes the waker, so waking afterwards is a no-op.
        drop(waiter);
        assert!(!waker.wake_up());
    }

    #[ktest]
    fn waiter_wake_async() {
        let (waiter, waker) = Waiter::new_pair();
        let flag = Arc::new(AtomicBool::new(false));
        let task_flag = Arc::clone(&flag);

        TaskOptions::new(move || {
            Task::yield_now();

            task_flag.store(true, Ordering::Relaxed);
            assert!(waker.wake_up());
        })
        .data(())
        .spawn()
        .unwrap();

        waiter.wait();

        assert!(flag.load(Ordering::Relaxed));
    }

    #[ktest]
    fn waiter_wake_reorder() {
        let (first_waiter, first_waker) = Waiter::new_pair();
        let first_flag = Arc::new(AtomicBool::new(false));
        let task_first_flag = Arc::clone(&first_flag);

        let (second_waiter, second_waker) = Waiter::new_pair();
        let second_flag = Arc::new(AtomicBool::new(false));
        let task_second_flag = Arc::clone(&second_flag);

        // The spawned task wakes the second pair before the first one, while the current task
        // waits on the first pair before the second one.
        TaskOptions::new(move || {
            Task::yield_now();

            task_second_flag.store(true, Ordering::Relaxed);
            assert!(second_waker.wake_up());

            Task::yield_now();

            task_first_flag.store(true, Ordering::Relaxed);
            assert!(first_waker.wake_up());
        })
        .data(())
        .spawn()
        .unwrap();

        first_waiter.wait();
        assert!(first_flag.load(Ordering::Relaxed));

        second_waiter.wait();
        assert!(second_flag.load(Ordering::Relaxed));
    }
}