1use core::{
4 cell::UnsafeCell,
5 fmt,
6 ops::{Deref, DerefMut},
7 sync::atomic::{AtomicBool, Ordering},
8};
9
10use super::WaitQueue;
11
/// A mutual-exclusion lock protecting a value of type `T`.
///
/// Ownership of the lock is tracked by the atomic `lock` flag. Callers of
/// [`Mutex::lock`] that cannot take it wait on `queue`, and releasing the lock
/// asks that queue to wake one waiter.
pub struct Mutex<T: ?Sized> {
    // `true` while a `MutexGuard` holds the lock, `false` when unlocked.
    lock: AtomicBool,
    // Waiters parked in `Mutex::lock` until the lock is released.
    queue: WaitQueue,
    // The protected value; reached only through a guard or via `get_mut`.
    val: UnsafeCell<T>,
}
18
19impl<T> Mutex<T> {
20 pub const fn new(val: T) -> Self {
22 Self {
23 lock: AtomicBool::new(false),
24 queue: WaitQueue::new(),
25 val: UnsafeCell::new(val),
26 }
27 }
28}
29
30impl<T: ?Sized> Mutex<T> {
31 #[track_caller]
35 pub fn lock(&self) -> MutexGuard<'_, T> {
36 self.queue.wait_until(|| self.try_lock())
37 }
38
39 pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
41 self.acquire_lock()
45 .then(|| unsafe { MutexGuard::new(self) })
46 }
47
48 pub fn get_mut(&mut self) -> &mut T {
53 self.val.get_mut()
54 }
55
56 fn unlock(&self) {
58 self.release_lock();
59 self.queue.wake_one();
60 }
61
62 fn acquire_lock(&self) -> bool {
63 self.lock
64 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
65 .is_ok()
66 }
67
68 fn release_lock(&self) {
69 self.lock.store(false, Ordering::Release);
70 }
71}
72
73impl<T: ?Sized + fmt::Debug> fmt::Debug for Mutex<T> {
74 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
75 fmt::Debug::fmt(&self.val, f)
76 }
77}
78
// SAFETY: a `Mutex<T>` owns its `T`, so moving the mutex to another thread moves
// the value with it; that is sound when `T: Send`.
// NOTE(review): this explicit impl also asserts that `WaitQueue` may be sent
// between threads — confirm against its definition.
unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
// SAFETY: through `&Mutex<T>` a thread can reach the value only via a
// `MutexGuard` (or not at all), and the atomic `lock` flag admits one guard at a
// time. Access is therefore serialized, which amounts to handing `T` between
// threads, so `T: Send` (not `T: Sync`) is the required bound.
unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}
81
/// Guard returned by [`Mutex::lock`] and [`Mutex::try_lock`].
///
/// While the guard is alive the mutex stays locked, and the protected value is
/// reachable through `Deref`/`DerefMut`. Dropping the guard unlocks the mutex.
#[clippy::has_significant_drop]
#[must_use]
pub struct MutexGuard<'a, T: ?Sized> {
    // The locked mutex whose value this guard exposes.
    mutex: &'a Mutex<T>,
}
88
impl<'a, T: ?Sized> MutexGuard<'a, T> {
    /// Wraps an already-locked `mutex` in a guard.
    ///
    /// # Safety
    ///
    /// The caller must have just acquired `mutex`'s lock and must not release it
    /// by any other means. The returned guard is the lock's only holder: it hands
    /// out `&T`/`&mut T` to the inner value and unlocks the mutex when dropped.
    /// Creating a guard without holding the lock would allow aliased mutable
    /// access and would release a lock owned by someone else.
    unsafe fn new(mutex: &'a Mutex<T>) -> MutexGuard<'a, T> {
        MutexGuard { mutex }
    }
}
98
99impl<T: ?Sized> Deref for MutexGuard<'_, T> {
100 type Target = T;
101
102 fn deref(&self) -> &Self::Target {
103 unsafe { &*self.mutex.val.get() }
104 }
105}
106
107impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
108 fn deref_mut(&mut self) -> &mut Self::Target {
109 unsafe { &mut *self.mutex.val.get() }
110 }
111}
112
113impl<T: ?Sized> Drop for MutexGuard<'_, T> {
114 fn drop(&mut self) {
115 self.mutex.unlock();
116 }
117}
118
119impl<T: ?Sized + fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
120 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
121 fmt::Debug::fmt(&**self, f)
122 }
123}
124
// Guards may not be moved to another thread, so a lock is always released where
// it was acquired.
// NOTE(review): `unlock` here is just an atomic store plus a wake, so the exact
// reason for this restriction is not visible in this file — confirm and document.
impl<T: ?Sized> !Send for MutexGuard<'_, T> {}
126
// SAFETY: a shared `&MutexGuard` exposes only `&T` (through `Deref`), so sharing
// the guard across threads is sound when `T: Sync`.
// NOTE(review): `MutexGuard::get_lock` returns `&'a Mutex<T>`, which outlives the
// borrow of the guard. A thread handed `&MutexGuard` can keep that reference and
// lock the mutex after the guard is dropped, obtaining `&mut T` without
// `T: Send` (which `Mutex<T>: Sync` normally requires). Consider requiring
// `T: Send + Sync` here — confirm no caller depends on the looser bound.
unsafe impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}
128
129impl<'a, T: ?Sized> MutexGuard<'a, T> {
130 pub fn get_lock(guard: &MutexGuard<'a, T>) -> &'a Mutex<T> {
132 guard.mutex
133 }
134}
135
#[cfg(ktest)]
mod test {
    use super::*;
    use crate::prelude::*;

    /// A failed `try_lock` must leave an already-held mutex locked, and dropping
    /// the real guard must release it so the mutex can be taken again.
    #[ktest]
    fn test_mutex_try_lock_does_not_unlock() {
        let lock = Mutex::new(0);
        assert!(!lock.lock.load(Ordering::Relaxed));

        // Take the lock through the waiting path.
        let guard1 = lock.lock();
        assert!(lock.lock.load(Ordering::Relaxed));

        // A contended `try_lock` fails and must not clear the flag.
        assert!(lock.try_lock().is_none());
        assert!(lock.lock.load(Ordering::Relaxed));

        // Dropping the guard releases the lock...
        drop(guard1);
        assert!(!lock.lock.load(Ordering::Relaxed));

        // ...so it can be re-acquired; the temporary guard unlocks again on drop.
        assert!(lock.try_lock().is_some());
        assert!(!lock.lock.load(Ordering::Relaxed));
    }
}