1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
//! A common utility for building synchronization primitives.
//!
//! When an async operation is blocked, it needs to register itself somewhere so that it can be
//! notified later on. The `WakerSet` type helps with keeping track of such async operations and
//! notifying them when they may make progress.

use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Waker};

use crossbeam_utils::Backoff;
use slab::Slab;

/// Bit of `WakerSet::flag` that is set while the entry list is locked.
const LOCKED: usize = 0b001;

/// Bit that is set when at least one entry has been notified but not yet removed.
const NOTIFIED: usize = 0b010;

/// Bit that is set when at least one entry is still waiting to be notified.
const NOTIFIABLE: usize = 0b100;

/// Lock-protected state of a `WakerSet`.
struct Inner {
    /// How many entries in `entries` still hold a waker, i.e. are still waiting to be notified.
    notifiable: usize,

    /// Registered blocked operations, keyed by their index in the `Slab`.
    ///
    /// `Some(waker)` means the operation is still waiting and will be woken through `waker`.
    /// `None` means the operation has already been woken but has not yet removed its own
    /// entry from the `WakerSet`.
    entries: Slab<Option<Waker>>,
}

/// A set holding wakers.
///
/// The entry list lives in an `UnsafeCell` and is guarded by a spin lock encoded in the
/// `LOCKED` bit of `flag`; it is only accessed through the `Lock` guard returned by
/// `WakerSet::lock()`.
pub struct WakerSet {
    /// Holds three bits: `LOCKED`, `NOTIFIED`, and `NOTIFIABLE`.
    ///
    /// `NOTIFIED` and `NOTIFIABLE` are recomputed every time the lock is released, so they can
    /// be read without taking the lock to skip work when nothing is waiting.
    flag: AtomicUsize,

    /// The entry list; only accessed while the `LOCKED` bit is held.
    inner: UnsafeCell<Inner>,
}

impl WakerSet {
    /// Creates a new `WakerSet`.
    #[inline]
    pub fn new() -> WakerSet {
        WakerSet {
            flag: AtomicUsize::new(0),
            inner: UnsafeCell::new(Inner {
                entries: Slab::new(),
                notifiable: 0,
            }),
        }
    }

    /// Inserts a waker for a blocked operation and returns a key associated with it.
    #[cold]
    pub fn insert(&self, cx: &Context<'_>) -> usize {
        let w = cx.waker().clone();
        let mut inner = self.lock();

        let key = inner.entries.insert(Some(w));
        inner.notifiable += 1;
        key
    }

    /// If the waker for this key is still waiting for a notification, then update
    /// the waker for the entry, and return false. If the waker has been notified,
    /// treat the entry as completed and return true.
    #[cfg(feature = "unstable")]
    pub fn remove_if_notified(&self, key: usize, cx: &Context<'_>) -> bool {
        let mut inner = self.lock();

        match &mut inner.entries[key] {
            None => {
                inner.entries.remove(key);
                true
            }
            Some(w) => {
                // We were never woken, so update instead
                if !w.will_wake(cx.waker()) {
                    *w = cx.waker().clone();
                }
                false
            }
        }
    }

    /// Removes the waker of a cancelled operation.
    ///
    /// Returns `true` if another blocked operation from the set was notified.
    #[cold]
    pub fn cancel(&self, key: usize) -> bool {
        let mut inner = self.lock();

        match inner.entries.remove(key) {
            Some(_) => inner.notifiable -= 1,
            None => {
                // The operation was cancelled and notified so notify another operation instead.
                for (_, opt_waker) in inner.entries.iter_mut() {
                    // If there is no waker in this entry, that means it was already woken.
                    if let Some(w) = opt_waker.take() {
                        w.wake();
                        inner.notifiable -= 1;
                        return true;
                    }
                }
            }
        }

        false
    }

    /// Notifies one additional blocked operation.
    ///
    /// Returns `true` if an operation was notified.
    #[inline]
    #[cfg(feature = "unstable")]
    pub fn notify_one(&self) -> bool {
        // Use `SeqCst` ordering to synchronize with `Lock::drop()`.
        if self.flag.load(Ordering::SeqCst) & NOTIFIABLE != 0 {
            self.notify(Notify::One)
        } else {
            false
        }
    }

    /// Notifies all blocked operations.
    ///
    /// Returns `true` if at least one operation was notified.
    #[inline]
    pub fn notify_all(&self) -> bool {
        // Use `SeqCst` ordering to synchronize with `Lock::drop()`.
        if self.flag.load(Ordering::SeqCst) & NOTIFIABLE != 0 {
            self.notify(Notify::All)
        } else {
            false
        }
    }

    /// Notifies blocked operations, either one or all of them.
    ///
    /// Returns `true` if at least one operation was notified.
    #[cold]
    fn notify(&self, n: Notify) -> bool {
        let mut inner = &mut *self.lock();
        let mut notified = false;

        for (_, opt_waker) in inner.entries.iter_mut() {
            // If there is no waker in this entry, that means it was already woken.
            if let Some(w) = opt_waker.take() {
                w.wake();
                inner.notifiable -= 1;
                notified = true;

                if n == Notify::One {
                    break;
                }
            }

            if n == Notify::Any {
                break;
            }
        }

        notified
    }

    /// Locks the list of entries.
    fn lock(&self) -> Lock<'_> {
        let backoff = Backoff::new();
        while self.flag.fetch_or(LOCKED, Ordering::Acquire) & LOCKED != 0 {
            backoff.snooze();
        }
        Lock { waker_set: self }
    }
}

/// A guard holding a `WakerSet` locked.
///
/// While the guard exists, it gives exclusive access to the set's `Inner` state through
/// `Deref`/`DerefMut`. Dropping it recomputes the `NOTIFIED`/`NOTIFIABLE` bits and clears
/// `LOCKED` in a single store.
struct Lock<'a> {
    waker_set: &'a WakerSet,
}

impl Drop for Lock<'_> {
    #[inline]
    fn drop(&mut self) {
        let mut flag = 0;

        // Set the `NOTIFIED` flag if there is at least one notified entry, i.e. more entries
        // than notifiable ones. Compare rather than subtract so an inconsistent count cannot
        // underflow (and panic) inside `drop`.
        if self.entries.len() > self.notifiable {
            flag |= NOTIFIED;
        }

        // Set the `NOTIFIABLE` flag if there is at least one notifiable entry.
        if self.notifiable > 0 {
            flag |= NOTIFIABLE;
        }

        // Storing a value without the `LOCKED` bit releases the lock. Use `SeqCst` ordering to
        // synchronize with the flag loads in `WakerSet::notify_one()` / `notify_all()`.
        self.waker_set.flag.store(flag, Ordering::SeqCst);
    }
}

impl Deref for Lock<'_> {
    type Target = Inner;

    #[inline]
    fn deref(&self) -> &Inner {
        // SAFETY: a `Lock` is only created by `WakerSet::lock()` after it has acquired the
        // `LOCKED` bit, and that bit is held until the guard is dropped, so no other reference
        // to `inner` can be live.
        unsafe { &*self.waker_set.inner.get() }
    }
}

impl DerefMut for Lock<'_> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Inner {
        // SAFETY: as in `deref`, the held `LOCKED` bit gives this guard exclusive access, and
        // `&mut self` ensures no other borrow through this guard is live.
        unsafe { &mut *self.waker_set.inner.get() }
    }
}

/// Which entries a notification should wake.
#[derive(PartialEq, Eq, Clone, Copy)]
enum Notify {
    /// Ensure at least one entry is notified: only the first entry is inspected.
    Any,
    /// Wake the next entry that is still waiting.
    One,
    /// Wake every entry that is still waiting.
    All,
}