#![allow(unused_variables)]
use core::{
    fmt::{self},
    sync::atomic::{AtomicU64, Ordering::Relaxed},
};
use align_ext::AlignExt;
use crate::prelude::*;
/// A fixed-length array of bits whose individual bits can be read and
/// modified through a shared reference (`&self`) using atomic operations.
///
/// Bits are packed 64 per `AtomicU64` word; every access in this file uses
/// `Relaxed` ordering.
pub struct AtomicBits {
    /// Number of valid bits. No method changes it after construction.
    num_bits: usize,
    /// Backing words: bit `i` lives at bit `i % 64` of word `i / 64`.
    /// When `num_bits` is not a multiple of 64, the high bits of the last word
    /// lie past `num_bits` and do not belong to the array.
    u64s: Box<[AtomicU64]>,
}
impl AtomicBits {
    /// Creates a bit array of `num_bits` bits, all initialized to 0.
    pub fn new_zeroes(num_bits: usize) -> Self {
        Self::new(0, num_bits)
    }

    /// Creates a bit array of `num_bits` bits, all initialized to 1.
    pub fn new_ones(num_bits: usize) -> Self {
        Self::new(!0, num_bits)
    }

    /// Allocates `ceil(num_bits / 64)` words, each initialized to `u64_val`.
    ///
    /// When `num_bits` is not a multiple of 64, the high bits of the last word
    /// lie past `num_bits`. They are initialized to `u64_val` as well, so every
    /// reader must ignore them.
    fn new(u64_val: u64, num_bits: usize) -> Self {
        let num_u64s = num_bits.align_up(64) / 64;
        let u64s = {
            let mut u64s = Vec::with_capacity(num_u64s);
            for _ in 0..num_u64s {
                u64s.push(AtomicU64::new(u64_val));
            }
            u64s.into_boxed_slice()
        };
        Self { num_bits, u64s }
    }

    /// Returns the number of bits in the array.
    pub fn len(&self) -> usize {
        self.num_bits
    }

    /// Returns the bit at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    pub fn get(&self, index: usize) -> bool {
        assert!(index < self.num_bits);
        let i = index / 64;
        let j = index % 64;
        // SAFETY: `index < num_bits` (asserted above) and `u64s` holds
        // `ceil(num_bits / 64)` words, hence `i = index / 64 < u64s.len()`.
        let u64_atomic = unsafe { self.u64s.get_unchecked(i) };
        (u64_atomic.load(Relaxed) & (1 << j)) != 0
    }

    /// Atomically sets the bit at `index` to `new_bit`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.len()`.
    pub fn set(&self, index: usize, new_bit: bool) {
        assert!(index < self.num_bits);
        let i = index / 64;
        let j = index % 64;
        // SAFETY: `index < num_bits` (asserted above) and `u64s` holds
        // `ceil(num_bits / 64)` words, hence `i = index / 64 < u64s.len()`.
        let u64_atomic = unsafe { self.u64s.get_unchecked(i) };
        if new_bit {
            u64_atomic.fetch_or(1 << j, Relaxed);
        } else {
            u64_atomic.fetch_and(!(1 << j), Relaxed);
        }
    }

    /// Sets every bit to 0.
    ///
    /// Each word is cleared by its own atomic store. A concurrent reader may
    /// therefore observe some words already cleared and others not yet.
    pub fn clear(&self) {
        for u64_atomic in self.u64s.iter() {
            u64_atomic.store(0, Relaxed);
        }
    }

    /// Returns `true` if every bit is 1 (vacuously `true` when `len() == 0`).
    pub fn is_full(&self) -> bool {
        self.match_pattern(!0)
    }

    /// Returns `true` if every bit is 0 (vacuously `true` when `len() == 0`).
    ///
    /// Note: this tests the bit values, not whether `len()` is zero.
    pub fn is_empty(&self) -> bool {
        self.match_pattern(0)
    }

    /// Returns `true` if, in every word, each valid bit equals the bit at the
    /// same position in `pattern`.
    ///
    /// The bits of the last word that lie past `num_bits` are masked out,
    /// because `new_ones` leaves them set to 1. Words are loaded one at a time,
    /// so under concurrent modification the result is not an atomic snapshot
    /// of the whole array.
    fn match_pattern(&self, pattern: u64) -> bool {
        let num_u64s = self.u64s.len();
        let num_tail_bits = self.num_bits % 64;
        self.u64s.iter().enumerate().all(|(i, u64_atomic)| {
            let valid_mask: u64 = if i + 1 == num_u64s && num_tail_bits != 0 {
                (1u64 << num_tail_bits) - 1
            } else {
                !0
            };
            ((u64_atomic.load(Relaxed) ^ pattern) & valid_mask) == 0
        })
    }

    /// Returns an iterator over all bits, as `bool`s, in index order.
    pub fn iter(&self) -> Iter<'_> {
        Iter::new(self)
    }

    /// Returns an iterator over the indices of the bits that are 1, ascending.
    pub fn iter_ones(&self) -> OnesIter<'_> {
        OnesIter::new(self)
    }

    /// Returns an iterator over the indices of the bits that are 0, ascending.
    pub fn iter_zeroes(&self) -> ZeroesIter<'_> {
        ZeroesIter::new(self)
    }
}
/// An iterator over every bit of an [`AtomicBits`], yielded as `bool`s in
/// index order.
///
/// Each bit is loaded when it is yielded, so concurrent modifications may or
/// may not be observed.
pub struct Iter<'a> {
    bits: &'a AtomicBits,
    /// Index of the next bit to yield; never exceeds `bits.len()`.
    bit_i: usize,
}

impl<'a> Iter<'a> {
    fn new(bits: &'a AtomicBits) -> Self {
        Self { bits, bit_i: 0 }
    }
}

impl<'a> Iterator for Iter<'a> {
    type Item = bool;

    fn next(&mut self) -> Option<bool> {
        if self.bit_i < self.bits.len() {
            let bit = self.bits.get(self.bit_i);
            self.bit_i += 1;
            Some(bit)
        } else {
            None
        }
    }

    /// The number of remaining bits is always known exactly, which lets
    /// consumers such as `collect` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `bit_i` only advances while it is below `len()`, so this cannot
        // underflow; `saturating_sub` keeps that explicit.
        let remaining = self.bits.len().saturating_sub(self.bit_i);
        (remaining, Some(remaining))
    }
}
/// An iterator over the indices of the bits of an [`AtomicBits`] that are 1,
/// in ascending order.
///
/// Words are loaded lazily, one at a time, so bits changed concurrently may or
/// may not be observed.
pub struct OnesIter<'a> {
    bits: &'a AtomicBits,
    /// Index of the word currently being scanned.
    u64_idx: usize,
    /// The 1 bits of the current word that have not been yielded yet.
    u64_val: u64,
    /// Number of high bits of the last word that lie past `bits.len()`.
    num_garbage_bits_in_last_u64: u8,
}

impl<'a> OnesIter<'a> {
    fn new(bits: &'a AtomicBits) -> Self {
        let num_garbage_bits_in_last_u64 = {
            if bits.len() % 64 != 0 {
                // `bits.len() % 64` is in 1..=63, so the cast cannot truncate.
                64 - ((bits.len() % 64) as u8)
            } else {
                0
            }
        };
        let mut new_self = Self {
            bits,
            u64_idx: 0,
            u64_val: 0,
            num_garbage_bits_in_last_u64,
        };
        // A zero-length `AtomicBits` has no words, and loading word 0 would
        // panic. With no words, `next` returns `None` before reading `u64_val`.
        if !bits.u64s.is_empty() {
            new_self.u64_val = new_self.get_u64_val(0);
        }
        new_self
    }

    /// Loads word `idx`, clearing any bits past `bits.len()` so they are never
    /// reported as set.
    fn get_u64_val(&self, idx: usize) -> u64 {
        let mut u64_val = self.bits.u64s[idx].load(Relaxed);
        if idx == self.bits.u64s.len() - 1 && self.num_garbage_bits_in_last_u64 > 0 {
            let num_valid_bits_in_last_u64 = 64 - self.num_garbage_bits_in_last_u64;
            let valid_bits_mask: u64 = (1 << num_valid_bits_in_last_u64) - 1;
            u64_val &= valid_bits_mask;
        }
        u64_val
    }
}

impl<'a> Iterator for OnesIter<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        loop {
            if self.u64_idx >= self.bits.u64s.len() {
                return None;
            }
            // `trailing_zeros` is 64 exactly when no 1 bits remain in the word.
            let first_one_in_u64 = self.u64_val.trailing_zeros() as usize;
            if first_one_in_u64 < 64 {
                // Drop the bit just found so the next call finds the next one.
                self.u64_val &= !(1 << first_one_in_u64);
                let one_pos = self.u64_idx * 64 + first_one_in_u64;
                return Some(one_pos);
            }
            self.u64_idx += 1;
            if self.u64_idx < self.bits.u64s.len() {
                self.u64_val = self.get_u64_val(self.u64_idx);
            }
        }
    }
}
/// An iterator over the indices of the bits of an [`AtomicBits`] that are 0,
/// in ascending order.
///
/// Words are loaded lazily, one at a time, so bits changed concurrently may or
/// may not be observed.
pub struct ZeroesIter<'a> {
    bits: &'a AtomicBits,
    /// Index of the word currently being scanned.
    u64_idx: usize,
    /// The current word, with every already-yielded 0 bit flipped to 1.
    u64_val: u64,
    /// Number of high bits of the last word that lie past `bits.len()`.
    num_garbage_bits_in_last_u64: u8,
}

impl<'a> ZeroesIter<'a> {
    fn new(bits: &'a AtomicBits) -> Self {
        let num_garbage_bits_in_last_u64 = {
            if bits.len() % 64 != 0 {
                // `bits.len() % 64` is in 1..=63, so the cast cannot truncate.
                64 - ((bits.len() % 64) as u8)
            } else {
                0
            }
        };
        let mut new_self = Self {
            bits,
            u64_idx: 0,
            u64_val: 0,
            num_garbage_bits_in_last_u64,
        };
        // A zero-length `AtomicBits` has no words, and loading word 0 would
        // panic. With no words, `next` returns `None` before reading `u64_val`.
        if !bits.u64s.is_empty() {
            new_self.u64_val = new_self.get_u64_val(0);
        }
        new_self
    }

    /// Loads word `idx`, forcing any bits past `bits.len()` to 1 so they are
    /// never reported as zero.
    fn get_u64_val(&self, idx: usize) -> u64 {
        let mut u64_val = self.bits.u64s[idx].load(Relaxed);
        if idx == self.bits.u64s.len() - 1 && self.num_garbage_bits_in_last_u64 > 0 {
            let num_valid_bits_in_last_u64 = 64 - self.num_garbage_bits_in_last_u64;
            let garbage_bits_mask: u64 = !((1 << num_valid_bits_in_last_u64) - 1);
            u64_val |= garbage_bits_mask;
        }
        u64_val
    }
}

impl<'a> Iterator for ZeroesIter<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        loop {
            if self.u64_idx >= self.bits.u64s.len() {
                return None;
            }
            // `trailing_ones` is 64 exactly when no 0 bits remain in the word.
            let first_zero_in_u64 = self.u64_val.trailing_ones() as usize;
            if first_zero_in_u64 < 64 {
                // Mark the zero just found as seen so the next call moves on.
                self.u64_val |= 1 << first_zero_in_u64;
                let zero_pos = self.u64_idx * 64 + first_zero_in_u64;
                return Some(zero_pos);
            }
            self.u64_idx += 1;
            if self.u64_idx < self.bits.u64s.len() {
                self.u64_val = self.get_u64_val(self.u64_idx);
            }
        }
    }
}
impl Clone for AtomicBits {
    /// Produces an independent `AtomicBits` holding a copy of the current bits.
    ///
    /// Every word is read separately with `Relaxed` ordering. If other threads
    /// are writing, the copy may therefore mix values from different moments.
    fn clone(&self) -> Self {
        let copied_words: Box<[AtomicU64]> = self
            .u64s
            .iter()
            .map(|word| AtomicU64::new(word.load(Relaxed)))
            .collect();
        Self {
            num_bits: self.num_bits,
            u64s: copied_words,
        }
    }
}
impl fmt::Debug for AtomicBits {
    /// Renders the bits in index order as `AtomicBits(0101...)`, one character
    /// per bit ('1' for set, '0' for clear).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("AtomicBits(")?;
        for is_set in self.iter() {
            let digit = if is_set { "1" } else { "0" };
            f.write_str(digit)?;
        }
        f.write_str(")")
    }
}
#[cfg(ktest)]
mod test {
    use super::*;

    /// `len` reports exactly the requested number of bits.
    #[ktest]
    fn new() {
        assert_eq!(AtomicBits::new_zeroes(1).len(), 1);
        assert_eq!(AtomicBits::new_zeroes(128).len(), 128);
        assert_eq!(AtomicBits::new_ones(7).len(), 7);
        assert_eq!(AtomicBits::new_zeroes(65).len(), 65);
    }

    /// Every bit can be flipped and flipped back independently.
    #[ktest]
    fn set_get() {
        let zeroes = AtomicBits::new_zeroes(128);
        for idx in 0..zeroes.len() {
            assert!(!zeroes.get(idx));
            zeroes.set(idx, true);
            assert!(zeroes.get(idx));
            zeroes.set(idx, false);
            assert!(!zeroes.get(idx));
        }

        let ones = AtomicBits::new_ones(128);
        for idx in 0..ones.len() {
            assert!(ones.get(idx));
            ones.set(idx, false);
            assert!(!ones.get(idx));
            ones.set(idx, true);
            assert!(ones.get(idx));
        }
    }

    /// `iter_ones` yields exactly the set bits, ignoring tail bits.
    #[ktest]
    fn iter_ones() {
        assert_eq!(AtomicBits::new_zeroes(1).iter_ones().count(), 0);
        assert_eq!(AtomicBits::new_zeroes(400).iter_ones().count(), 0);
        for num_bits in [1, 24, 64, 77, 128] {
            assert_eq!(AtomicBits::new_ones(num_bits).iter_ones().count(), num_bits);
        }

        let sparse = AtomicBits::new_zeroes(8);
        for idx in [1, 3, 5] {
            sparse.set(idx, true);
        }
        assert_eq!(sparse.iter_ones().count(), 3);
    }

    /// `iter_zeroes` yields exactly the clear bits, ignoring tail bits.
    #[ktest]
    fn iter_zeroes() {
        assert_eq!(AtomicBits::new_ones(1).iter_zeroes().count(), 0);
        assert_eq!(AtomicBits::new_ones(130).iter_zeroes().count(), 0);
        for num_bits in [1, 24, 64, 77, 128] {
            assert_eq!(AtomicBits::new_zeroes(num_bits).iter_zeroes().count(), num_bits);
        }

        let sparse = AtomicBits::new_ones(96);
        for idx in [1, 3, 5, 64, 76] {
            sparse.set(idx, false);
        }
        assert_eq!(sparse.iter_zeroes().count(), 5);
    }

    /// `iter` reports the initial value of every bit.
    #[ktest]
    fn iter() {
        assert!(AtomicBits::new_zeroes(7).iter().all(|bit| !bit));
        assert!(AtomicBits::new_ones(128).iter().all(|bit| bit));
    }
}