use candle::utils::get_num_threads;
use std::sync::Arc;
use std::thread;
/// Scalar types accepted by the FFT and log-mel routines in this module.
///
/// Bundles the `num_traits` operations the generic code relies on (float math,
/// `PI`, compound assignment) with `Send + Sync`, so buffers of these values
/// can be shared with the worker threads spawned by the spectrogram code.
pub trait Float
where
    Self: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign + Send + Sync,
{
}

impl Float for f32 {}

impl Float for f64 {}
/// Recursive radix-2 FFT of a real-valued signal.
///
/// Returns `2 * inp.len()` values laid out as interleaved `(re, im)` pairs, one
/// pair per frequency bin (the same layout as [`dft`]).
///
/// - Even lengths are split into even- and odd-indexed halves and recombined
///   with twiddle factors (Cooley–Tukey).
/// - Odd lengths fall back to the O(n²) [`dft`].
/// - An empty input yields an empty output.
fn fft<T: Float>(inp: &[T]) -> Vec<T> {
    let n = inp.len();
    let zero = T::zero();
    // Without this guard an empty slice would take the even-length path below
    // and recurse on empty halves forever.
    if n == 0 {
        return Vec::new();
    }
    if n == 1 {
        return vec![inp[0], zero];
    }
    if n % 2 == 1 {
        return dft(inp);
    }

    let half_n = n / 2;
    let even: Vec<T> = inp.iter().copied().step_by(2).collect();
    let odd: Vec<T> = inp.iter().copied().skip(1).step_by(2).collect();
    let even_fft = fft(&even);
    let odd_fft = fft(&odd);

    let two_pi = T::PI() + T::PI();
    let n_t = T::from(n).unwrap();
    let mut out = vec![zero; n * 2];
    for k in 0..half_n {
        // Twiddle factor w = exp(-2πi·k/n) = re + i·im.
        let k_t = T::from(k).unwrap();
        let theta = two_pi * k_t / n_t;
        let re = theta.cos();
        let im = -theta.sin();
        let re_odd = odd_fft[2 * k];
        let im_odd = odd_fft[2 * k + 1];
        // Butterfly: X[k] = E[k] + w·O[k], X[k + n/2] = E[k] - w·O[k].
        out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
        out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
        out[2 * (k + half_n)] = even_fft[2 * k] - re * re_odd + im * im_odd;
        out[2 * (k + half_n) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
    }
    out
}
/// Naive O(n²) discrete Fourier transform of a real-valued signal.
///
/// For each bin `k` in `0..inp.len()` this computes
/// `X[k] = Σ_j x[j] · exp(-2πi·k·j / n)` and emits it as an interleaved
/// `(re, im)` pair, so the result has `2 * inp.len()` values.
fn dft<T: Float>(inp: &[T]) -> Vec<T> {
    let two_pi = T::PI() + T::PI();
    let n_t = T::from(inp.len()).unwrap();
    (0..inp.len())
        .flat_map(|k| {
            let k_t = T::from(k).unwrap();
            let (re, im) = inp.iter().enumerate().fold(
                (T::zero(), T::zero()),
                |(acc_re, acc_im), (j, &sample)| {
                    let angle = two_pi * k_t * T::from(j).unwrap() / n_t;
                    (acc_re + sample * angle.cos(), acc_im - sample * angle.sin())
                },
            );
            [re, im]
        })
        .collect()
}
/// Worker for [`log_mel_spectrogram_`]: computes log-mel values for the frames
/// assigned to worker `ith` out of `n_threads`.
///
/// Frame `i` starts at sample `i * fft_step`. This worker handles
/// `i = ith, ith + n_threads, ith + 2 * n_threads, ...` while `i` is below
/// `min(samples.len() / fft_step + 1, n_len)`.
///
/// Each frame goes through these steps:
/// 1. It is multiplied by `hann`, and zero-filled where it runs past the end of
///    `samples`.
/// 2. It is transformed with [`fft`].
/// 3. It is reduced to an `n_fft`-bin power spectrum.
/// 4. It is projected onto each of the `n_mel` rows of `filters`.
///
/// `filters` is indexed as a row-major `n_mel x n_fft` matrix, where
/// `n_fft = 1 + fft_size / 2`, or `1 + fft_size / 4` when `speed_up` is set.
///
/// Returns a buffer of length `n_len * n_mel` laid out as `mel[j * n_len + i]`
/// (mel row `j`, frame `i`). Entries for frames this worker does not handle
/// stay zero. Handled entries hold `log10(max(sum, 1e-10))`.
///
/// # Panics
///
/// Panics if `n_threads == 0` (`step_by(0)`), or on out-of-bounds indexing if
/// `hann` or `filters` are shorter than the sizes described above.
#[allow(clippy::too_many_arguments)]
fn log_mel_spectrogram_w<T: Float>(
    ith: usize,
    hann: &[T],
    samples: &[T],
    filters: &[T],
    fft_size: usize,
    fft_step: usize,
    speed_up: bool,
    n_len: usize,
    n_mel: usize,
    n_threads: usize,
) -> Vec<T> {
    let n_fft = if speed_up {
        1 + fft_size / 4
    } else {
        1 + fft_size / 2
    };
    let zero = T::zero();
    let half = T::from(0.5).unwrap();
    let mut fft_in = vec![zero; fft_size];
    let mut mel = vec![zero; n_len * n_mel];
    let n_samples = samples.len();
    // Every frame index below `end` satisfies `i * fft_step <= n_samples`, so
    // `n_samples - offset` in the loop cannot underflow.
    let end = std::cmp::min(n_samples / fft_step + 1, n_len);
    for i in (ith..end).step_by(n_threads) {
        let offset = i * fft_step;
        // Apply the window to the samples available for this frame.
        for j in 0..std::cmp::min(fft_size, n_samples - offset) {
            fft_in[j] = hann[j] * samples[offset + j];
        }
        // Zero the tail when the frame runs past the end of the input.
        if n_samples - offset < fft_size {
            fft_in[n_samples - offset..].fill(zero);
        }
        // `fft_out` is interleaved (re, im) with 2 * fft_size values.
        let mut fft_out: Vec<T> = fft(&fft_in);
        // Overwrite slot j in place with |X[j]|^2. This is safe because slots
        // 2j and 2j+1 (both >= j) have not been overwritten yet; only slots < j
        // have.
        for j in 0..fft_size {
            fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
        }
        // Add the power of the mirrored bin `fft_size - j` onto bin `j`.
        for j in 1..fft_size / 2 {
            let v = fft_out[fft_size - j];
            fft_out[j] += v;
        }
        if speed_up {
            // Average adjacent bin pairs in place (same ordering argument as
            // above), leaving `n_fft` bins at the front.
            for j in 0..n_fft {
                fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
            }
        }
        // Dot product of the power spectrum with each mel filter row.
        for j in 0..n_mel {
            let mut sum = zero;
            let mut k = 0;
            // Main loop unrolled by 4.
            while k < n_fft.saturating_sub(3) {
                sum += fft_out[k] * filters[j * n_fft + k]
                    + fft_out[k + 1] * filters[j * n_fft + k + 1]
                    + fft_out[k + 2] * filters[j * n_fft + k + 2]
                    + fft_out[k + 3] * filters[j * n_fft + k + 3];
                k += 4;
            }
            // Remaining bins.
            while k < n_fft {
                sum += fft_out[k] * filters[j * n_fft + k];
                k += 1;
            }
            // Floor at 1e-10 so log10 never sees zero.
            mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
        }
    }
    mel
}
/// Computes a normalized log-mel spectrogram of `samples`.
///
/// Padding:
/// - The frame count `samples.len() / fft_step` is rounded up to a multiple of
///   `pad = 100 * super::CHUNK_LENGTH / 2`.
/// - One more `pad` of frames is then added.
/// - The audio is zero-padded to cover all of those frames.
///
/// Each frame of `fft_size` samples (hop `fft_step`) is processed by
/// [`log_mel_spectrogram_w`]. The frames are split across several worker
/// threads.
///
/// `filters` is read as a row-major `n_mel x n_fft` matrix, where
/// `n_fft = 1 + fft_size / 2`, or `1 + fft_size / 4` when `speed_up` is set.
///
/// Returns `n_mel * n_len` values laid out as `mel[j * n_len + i]` (mel row
/// `j`, frame `i`). Values are `log10` powers, floored at the global
/// maximum minus 8, then mapped through `v / 4 + 1`.
///
/// # Panics
///
/// Panics if `fft_step == 0`, or if a worker thread panics (for example
/// because `filters` holds fewer than `n_mel * n_fft` values).
pub fn log_mel_spectrogram_<T: Float>(
    samples: &[T],
    filters: &[T],
    fft_size: usize,
    fft_step: usize,
    n_mel: usize,
    speed_up: bool,
) -> Vec<T> {
    let zero = T::zero();
    let two_pi = T::PI() + T::PI();
    let half = T::from(0.5).unwrap();
    let one = T::from(1.0).unwrap();
    let four = T::from(4.0).unwrap();
    let fft_size_t = T::from(fft_size).unwrap();

    let hann: Vec<T> = (0..fft_size)
        .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
        .collect();

    // Round the frame count up to a multiple of `pad`, then add one extra
    // `pad` of frames; the audio is zero-padded to cover all of them.
    let n_len = samples.len() / fft_step;
    let pad = 100 * super::CHUNK_LENGTH / 2;
    let n_len = if n_len % pad != 0 {
        (n_len / pad + 1) * pad
    } else {
        n_len
    };
    let n_len = n_len + pad;
    let samples = {
        let mut samples_padded = samples.to_vec();
        samples_padded.resize(n_len * fft_step, zero);
        samples_padded
    };

    // Keep the worker count even and within [2, 12]. Without the lower bound a
    // single available thread yields zero workers: nothing would be spawned,
    // and reading the (empty) worker outputs below would panic.
    let available = get_num_threads();
    let n_threads = (available - available % 2).clamp(2, 12);

    let hann = Arc::new(hann);
    let samples = Arc::new(samples);
    let filters = Arc::new(filters);

    // Scoped threads are all joined before `scope` returns. Worker `thread_id`
    // handles frames `thread_id, thread_id + n_threads, ...` and returns a
    // full-size buffer that is zero for every other frame.
    let all_outputs = thread::scope(|s| {
        (0..n_threads)
            .map(|thread_id| {
                let hann = Arc::clone(&hann);
                let samples = Arc::clone(&samples);
                let filters = Arc::clone(&filters);
                s.spawn(move || {
                    log_mel_spectrogram_w(
                        thread_id, &hann, &samples, &filters, fft_size, fft_step, speed_up, n_len,
                        n_mel, n_threads,
                    )
                })
            })
            .collect::<Vec<_>>()
            .into_iter()
            .map(|handle| handle.join().expect("Thread failed"))
            .collect::<Vec<_>>()
    });

    // Workers fill disjoint frames, so an element-wise sum of their buffers
    // reassembles the full spectrogram.
    let mut mel = vec![zero; n_len * n_mel];
    for thread_output in &all_outputs {
        for (m, &v) in mel.iter_mut().zip(thread_output.iter()) {
            *m += v;
        }
    }

    // Dynamic-range compression: floor everything at (max - 8), then rescale.
    let mmax = mel
        .iter()
        .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
        .copied()
        .unwrap_or(zero)
        - T::from(8).unwrap();
    for m in mel.iter_mut() {
        let v = T::max(*m, mmax);
        *m = v / four + one
    }
    mel
}
/// Converts raw PCM `samples` into a normalized log-mel spectrogram.
///
/// Settings:
/// - window size: `super::N_FFT`
/// - hop length: `super::HOP_LENGTH`
/// - mel filter rows taken from `filters`: `cfg.num_mel_bins`
/// - frequency-domain speed-up: disabled
pub fn pcm_to_mel<T: Float>(cfg: &super::Config, samples: &[T], filters: &[T]) -> Vec<T> {
    let (fft_size, hop) = (super::N_FFT, super::HOP_LENGTH);
    let n_mel = cfg.num_mel_bins;
    let speed_up = false;
    log_mel_spectrogram_(samples, filters, fft_size, hop, n_mel, speed_up)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The FFT of a unit impulse at index 1 should be `exp(-2πi·k/4)` per bin,
    /// emitted as interleaved `(re, im)` pairs.
    #[test]
    fn test_fft() {
        let impulse: Vec<f64> = vec![0.0, 1.0, 0.0, 0.0];
        let expected: Vec<f64> = vec![
            1.0,
            0.0,
            6.123233995736766e-17,
            -1.0,
            -1.0,
            0.0,
            -6.123233995736766e-17,
            1.0,
        ];
        assert_eq!(fft(&impulse), expected);
    }

    /// Same impulse through the naive DFT; rounding differs from the FFT path
    /// in the low-order bits, so the expected values differ slightly.
    #[test]
    fn test_dft() {
        let impulse: Vec<f64> = vec![0.0, 1.0, 0.0, 0.0];
        let expected: Vec<f64> = vec![
            1.0,
            0.0,
            6.123233995736766e-17,
            -1.0,
            -1.0,
            -1.2246467991473532e-16,
            -1.8369701987210297e-16,
            1.0,
        ];
        assert_eq!(dft(&impulse), expected);
    }

    /// Output length is `n_mel * n_len` after padding. NOTE(review): the
    /// expected size assumes `super::CHUNK_LENGTH == 30` (pad = 1500 frames).
    #[test]
    fn test_log_mel_spectrogram() {
        let silence = vec![0.0f64; 1000];
        let filters = vec![0.0f64; 1000];
        let mel = log_mel_spectrogram_(&silence, &filters, 100, 10, 10, false);
        assert_eq!(mel.len(), 30_000);
    }

    /// Small window (20 samples) exercises the FFT -> DFT fallback on odd halves.
    #[test]
    fn test_tiny_log_mel_spectrogram() {
        let silence = vec![0.0f64; 100];
        let filters = vec![0.0f64; 100];
        let mel = log_mel_spectrogram_(&silence, &filters, 20, 2, 2, false);
        assert_eq!(mel.len(), 6_000);
    }
}