Lisa Sprint

Contents

Open In Colab

Lisa Sprint

A JAX implementation of the WDM (Wilson-Daubechies-Meyer) wavelet transform, following Neil Cornish's 2020 paper as closely as possible.

import numpy as np
from scipy.special import betainc
from scipy.signal import chirp
import jax
import matplotlib.pyplot as plt
import jax.numpy as jnp
from matplotlib.colors import TwoSlopeNorm
import glob
import os
from PIL import Image

# Enable 64-bit types in JAX so the jnp.complex128 arrays used below are not
# silently truncated to 32-bit precision.
jax.config.update("jax_enable_x64", True)

# Report the kind of the first device JAX exposes.
DEVICE = jax.devices()[0].device_kind
print(f"Running JAX on {DEVICE}")

PI = np.pi
# [f0, f1] passed to scipy.signal.chirp when the test signal is generated below.
FREQ_RANGE = [20, 100]


def omega(Nf: int, Nt: int):
    """Return the angular-frequency grid on which the window is sampled.

    The grid has spacing ``2*pi / (Nf*Nt)`` and runs over the integer bins
    ``-Nt // 2, ..., Nt // 2 - 1`` (so it contains ``Nt`` points for even Nt),
    returned as a float64 NumPy array.
    """
    step = 2 * np.pi / (Nf * Nt)
    bins = np.arange(-Nt // 2, Nt // 2, dtype=np.float64)
    return step * bins


def phitilde_vec_norm(Nf: int, Nt: int, d: float):
    """Evaluate the frequency-domain window on the ``omega(Nf, Nt)`` grid.

    The window values come from ``_phitilde_vec`` and are multiplied by
    ``sqrt(pi)`` before being returned.
    """
    grid = omega(Nf, Nt)
    window = _phitilde_vec(grid, Nf, d)
    return window * np.sqrt(np.pi)


def phi_vec(Nf: int, d: float = 4.0, q: int = 16):
    """Build the time-domain window phi as the inverse FFT of ``_phitilde_vec``.

    Nf: number of frequency layers, forwarded to ``_phitilde_vec``.
    d:  steepness parameter, forwarded to ``_phitilde_vec``.
    q:  length multiplier; the returned window has ``K = 2 * q * Nf`` samples
        (presumably the number of Nf bins the window spans -- to be confirmed).

    Returns a real float64 array of length K.
    """
    n_half = q * Nf
    n_total = 2 * n_half  # K in the original notation

    # Frequency step such that n_half * d_omega = pi (the maximum frequency).
    d_omega = 2 * np.pi / n_total

    # Spectrum laid out in standard FFT order: [0, positive..., negative...].
    spectrum = np.zeros(n_total, dtype=np.complex128)
    spectrum[0] = 1.0 / np.sqrt(np.pi / Nf)  # zero-frequency value

    positive = d_omega * np.arange(1, n_half + 1)
    negative = -d_omega * np.arange(n_half - 1, 0, -1)
    spectrum[1 : n_half + 1] = _phitilde_vec(positive, Nf, d)
    spectrum[n_half + 1 :] = _phitilde_vec(negative, Nf, d)

    samples = n_total * np.fft.ifft(spectrum, n_total)

    # Swap the two halves so that sample 0 of the inverse FFT lands at index
    # n_half, then keep only the real part.
    phi = np.concatenate((samples[n_half:], samples[:n_half])).real

    scale = np.sqrt(2.0) / np.sqrt(n_total / d_omega)
    return phi * scale


def _phitilde_vec(omega, Nf: int, d: float = 4.0):
    """Compute the frequency-domain window phi_tilde(omega) (Eq 11 of Cornish '20).

    omega: angular frequencies at which to evaluate the window. Any array-like
           (ndarray, list, ...) of any shape is accepted; the result has the
           same shape.
    Nf:    number of frequency layers; sets the bandwidth
           ``dOmega = 2*pi / (2*Nf)``.
    d:     filter steepness, used as both shape parameters of the regularised
           incomplete beta function. Defaults to 4.

    The window is
      * ``1/sqrt(dOmega)``                                    for |omega| < A,
      * ``1/sqrt(dOmega) * cos(pi/2 * betainc(d, d, (|omega| - A) / B))``
                                                         for A <= |omega| < A + B,
      * ``0``                                                 otherwise,
    with ``A = dOmega / 4`` and ``B = dOmega - 2*A``.

    Returns a float64 ndarray.
    Raises ValueError if B is not strictly positive.
    """
    omega = np.asarray(omega, dtype=np.float64)

    dF = 1.0 / (2 * Nf)  # NOTE: missing 1/dt?
    dOmega = 2 * np.pi * dF  # Near Eq 10 # 2 pi times DF
    inverse_sqrt_dOmega = 1.0 / np.sqrt(dOmega)

    A = dOmega / 4
    B = dOmega - 2 * A  # Cannot have B \leq 0.
    if B <= 0:
        raise ValueError("B must be greater than 0")

    # |omega| is needed for every region test, so compute it once.
    abs_omega = np.abs(omega)

    phi = np.zeros(omega.shape, dtype=np.float64)

    # Transition band: smooth roll-off from 1/sqrt(dOmega) down to 0.
    taper = (A <= abs_omega) & (abs_omega < A + B)
    vd = (np.pi / 2.0) * betainc(d, d, (abs_omega[taper] - A) / B)
    phi[taper] = inverse_sqrt_dOmega * np.cos(vd)

    # Flat pass band.
    phi[abs_omega < A] = inverse_sqrt_dOmega
    return phi


def xtilde_i(
    m: int, data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
) -> jnp.ndarray:
    """Return the inverse FFT of the windowed slice of ``data`` for layer ``m``.

    ``data`` is rolled left by ``m * Nt // 2`` samples, a slice of width
    ``2 * (Nt // 2)`` around index ``(Nt * Nf) // 2 + 1`` is taken, multiplied
    element-wise by ``phif`` and inverse-FFT'd.

    NOTE(review): the caller passes fftshift-ed data, whose zero-frequency bin
    sits at index ``(Nt * Nf) // 2`` for even lengths; the ``+ 1`` below looks
    like a one-bin offset -- confirm against Eq 17 of the paper.
    """
    half_width = Nt // 2
    centre = (Nt * Nf) // 2 + 1
    shifted = jnp.roll(data, -(m * Nt // 2))
    segment = shifted[centre - half_width : centre + half_width]
    return jnp.fft.ifft(segment * phif)


def Cnm_matrix(Nf: int, Nt: int) -> jnp.ndarray:
    """Return the ``(Nf, Nt)`` matrix of phase factors with their signs.

    Entry ``(i, j)`` is ``1`` when ``i + j`` is even and ``1j`` when it is odd,
    multiplied by ``+1`` when ``i * j`` is even and ``-1`` when it is odd.
    """
    rows = jnp.arange(Nf)[:, None]
    cols = jnp.arange(Nt)[None, :]

    phase = jnp.where((rows + cols) % 2 == 0, 1, 1j)
    sign = jnp.where((rows * cols) % 2 == 0, 1, -1)

    return phase * sign


def freq_to_wdm(
    data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
) -> jnp.ndarray:
    """Transform frequency-domain ``data`` into a real grid of WDM coefficients.

    data: FFT of the time series (standard FFT ordering); it is fftshift-ed
          here before the windowed slices are taken.
    Nf, Nt: grid dimensions. They must be equal: each column of the
          ``(Nf, Nt)`` output has Nf entries but is filled with the length-Nt
          vector returned by ``xtilde_i``.
    phif: frequency-domain window, multiplied element-wise with each slice
          inside ``xtilde_i``.

    Returns ``sqrt(2) * Re(X * Cnm_matrix(Nf, Nt))`` where column ``i`` of ``X``
    is ``xtilde_i(i, fftshift(data), Nf, Nt, phif)``.

    Raises ValueError if ``Nf != Nt`` (previously this surfaced as an opaque
    shape error from inside the JAX loop).
    """
    if Nf != Nt:
        raise ValueError(
            f"freq_to_wdm currently requires Nf == Nt (got Nf={Nf}, Nt={Nt})"
        )

    # fftshift data to match up with Neil Eq 17
    data_shifted = jnp.fft.fftshift(data)
    output = jnp.zeros((Nf, Nt), dtype=jnp.complex128)

    def fill_column(i, acc):
        # Functional update: returns a new array with column i replaced.
        return acc.at[:, i].set(xtilde_i(i, data_shifted, Nf, Nt, phif))

    coefficients = jax.lax.fori_loop(0, Nt, fill_column, output)
    return (coefficients * Cnm_matrix(Nf, Nt)).real * jnp.sqrt(2.0)


# --- Demo: WDM transform of a quadratic chirp --------------------------------
# Grid size: the transform below is run with as many time bins as frequency bins.
Nf = 32
Nt = Nf
nx = 4.0  # overwritten by `nx = 1` a few lines below, so this value is never used

# Bit trick: a power of two has a single set bit, so Nf & (Nf - 1) is 0.
assert Nf & (Nf - 1) == 0, "Nf must be a power of 2"
fs = 256  # sampling rate; dt = 1 / fs is the sample spacing used below
dt = 1 / fs
mult = 16  # not referenced anywhere else in the visible code
nx = 1  # window steepness actually passed (as `d`) to the window builders
ND = Nt * Nf  # total number of time samples
t = np.arange(0, ND) * dt
f = np.fft.rfftfreq(ND, dt)  # one-sided frequency axis, used for the FFT plot
T = ND * dt  # total duration of the signal
y = chirp(t, f0=FREQ_RANGE[0], f1=FREQ_RANGE[1], t1=t[-1], method="quadratic")
phif = jnp.array(phitilde_vec_norm(Nf, Nt, d=nx))  # frequency-domain window
phit = phi_vec(Nf, d=nx)  # time-domain window (not used by freq_to_wdm)
yf = jnp.fft.fft(y)

# wave = transform_wavelet_freq_helper_numba(yf, Nf, Nt, phif)
yf = jnp.array(yf)
jax_phif = jnp.array(phif)

# Compare the window shape for several steepness values.
plt.plot(phitilde_vec_norm(Nf, Nt, d=1), label="d=1")
plt.plot(phitilde_vec_norm(Nf, Nt, d=4), label="d=4")
plt.plot(phitilde_vec_norm(Nf, Nt, d=100), label="d=100")
plt.legend()
plt.show()

wave = freq_to_wdm(yf, Nf, Nt, jax_phif)
# Bin widths and bin edges used to label the axes of the WDM image.
delta_T = T / Nt
delta_F = 1 / (2 * delta_T)
t_bins = np.arange(0, Nt) * delta_T
f_bins = np.arange(0, Nf) * delta_F

# Top panel: power spectrum of the chirp; bottom panel: |WDM coefficients|.
fig, axes = plt.subplots(2, 1, figsize=(4, 6))
# yf is the full two-sided FFT; keep the first ND // 2 + 1 bins to match `f`.
axes[0].semilogy(f, (np.abs(yf) ** 2)[: ND // 2 + 1], label="FFT")
axes[0].set_xlabel("Frequency (Hz)", fontsize=14)
axes[0].set_ylabel("Power", fontsize=14)
axes[0].set_xlim(10, 100)
axes[0].set_ylim(bottom=10)

# `wave` is transposed so that its second axis runs along the vertical
# (frequency) axis of the image.
im = axes[1].imshow(
    np.abs((wave.T)),
    aspect="auto",
    extent=[t_bins[0], t_bins[-1], f_bins[0], f_bins[-1]],
    origin="lower",
    interpolation="nearest",
)
axes[1].set_xlabel("Time (s)", fontsize=14)
axes[1].set_ylabel("Frequency (Hz)", fontsize=14)
plt.tight_layout()

fig.subplots_adjust(hspace=0.5)
plt.show()


#
#
# def gnm_matrix(omega, Nf, Nt, A, B, d):
#     """
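#     \tilde{g}_{nm}(\omega) = e^{-i n \omega \Delta T} \left( C_{nm} \Phi(\omega - m \Delta\Omega) + C^{*}_{nm} \Phi(\omega + m \Delta\Omega) \right)
#
#
#     2A + B = \Delta\Omega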
#     """
#
#
#     cnm = Cnm_matrix(Nf, Nt)
#     cnmconj = jnp.conjugate(cnm)
#
#
#
#     v1 = jnp.exp(-1j * omega[:, None] * jnp.arange(Nt) / Nf)
#
#     phi =
#
#
#     return  v1 * cnm * ph +

#
# def gnm(Cnm, CnmConj, A, B, d):  ### Eq 10
#     """
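#     g_{nm}(\omega) = e^{-i n \omega \Delta T} \left( C_{nm} \Phi(\omega - m \Delta\Omega) + C^{*}_{nm} \Phi(\omega + m \Delta\Omega) \right)
#     2A + B = \Delta\Omega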
#     """
#
#     # omega = jnp.fft.fftshift(omega)
#     # omega = jnp.roll(omega, -Nf // 2)
#     # omega = jnp.fft.ifftshift(omega)
#
#      _phitilde_vec(omegas, Nf, d) * np.sqrt(np.pi)
#
#     return v1 * Cnm * ph + v1 * CnmConj * ph


#
# def


# wave2 = freq_to_wdm(yf, Nf, Nt, jax_phif)
Running JAX on cpu
../../_images/2a6d44b3f89ae61cc4f4c70f3088f9ac5bc3a7491e7c38ac68453d45b1e3d4a0.png ../../_images/38096b59e006140efac8f96723c214cea73e306d8f030f933da945a76605498d.png

Animation

import os

# Directory that receives one PNG per animation frame.
plt_dir = "wdm_plots"
os.makedirs(plt_dir, exist_ok=True)

# Two-sided, fftshift-ed frequency axis and spectrum (replaces the one-sided
# `f` defined for the earlier FFT plot).
f = np.fft.fftfreq(ND, dt)
f = np.fft.fftshift(f)
d = jnp.fft.fftshift(yf)
output = jnp.zeros((Nf, Nt), dtype=jnp.complex128)
# Same slice bounds as used inside xtilde_i.
i0 = (Nt * Nf) // 2 + 1
j0 = i0 - (Nt // 2)
j1 = i0 + (Nt // 2)
min_wnm, max_wnm = np.min(wave), np.max(wave)

# Loop-invariant: the window rescaled to the peak of the power spectrum.
# Computed once instead of on every frame.
window_overlay = max(np.abs(d) ** 2) * phif / max(phif)

for m in range(Nt):
    # Reproduce the per-layer computation of xtilde_i so it can be plotted.
    rolled_d = jnp.roll(d, -(m * Nt // 2))
    xni = jnp.fft.ifft(rolled_d[j0:j1] * phif)

    fig = plt.figure(figsize=(5, 6))
    gs = fig.add_gridspec(2, 2)

    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[1, :])

    # Top-left: original spectrum (faint), rolled spectrum and the window.
    ax1.plot(f, np.abs(d) ** 2, color="tab:gray", alpha=0.2)
    ax1.plot(f, np.abs(rolled_d) ** 2, color="tab:gray", label="Roll(y)")
    # NOTE(review): the window is drawn against bin indices while the other
    # curves on this axis use `f` in Hz -- confirm whether f[j0:j1] was meant.
    ax1.fill_between(
        np.arange(-Nt // 2, Nt / 2),
        window_overlay,
        alpha=0.5,
        color="tab:orange",
        label="Window",
    )
    ax1.legend(loc="upper left", frameon=False)
    ax1.set_ylim(bottom=0)
    # Top-right: the inverse FFT of the windowed slice.
    # NOTE(review): xni is complex; matplotlib presumably plots only its real
    # part (with a warning) -- confirm this is intended.
    ax2.plot(t_bins, xni)
    ax2.set_ylim(min_wnm, max_wnm)

    ax1.set_xlabel("Freq [Hz]")
    ax1.set_ylabel("Power")

    ax2.set_xlabel("Time [s]")
    ax2.set_ylabel("Wnm")

    # Bottom: the WDM grid with every column after m blanked out, so the image
    # fills in progressively over the frames.
    wdm_m = np.array(wave)
    wdm_m[:, m + 1 :] = 0
    # cmap of 'bwr' with vcenter=0

    # Keep the colour range at least [-1, 1] so it always straddles vcenter=0.
    curr_min = min(wdm_m.min(), -1)
    curr_max = max(wdm_m.max(), 1)

    norm = TwoSlopeNorm(0, curr_min, curr_max)
    im = ax3.imshow(
        wdm_m.T,
        aspect="auto",
        extent=[t_bins[0], t_bins[-1], f_bins[0], f_bins[-1]],
        origin="lower",
        interpolation="nearest",
        norm=norm,
        cmap="bwr",
    )
    ax3.set_xlabel("Time [s]", fontsize=14)
    ax3.set_ylabel("Freq [Hz]", fontsize=14)
    # add colorbar and label it Wnm
    cbar = fig.colorbar(im, ax=ax3)
    cbar.set_label("Wnm", fontsize=14)
    plt.tight_layout()
    plt.savefig(f"{plt_dir}/wdm_{m:02d}.png")
    # Release the figure once it is on disk; otherwise one figure per frame
    # stays alive in pyplot's registry for the rest of the session.
    plt.close(fig)


def make_gif_from_images(image_dir, gif_name, duration=3):
    """Combine the PNG frames found in ``image_dir`` into a looping GIF.

    image_dir: directory searched (non-recursively) for ``*.png`` files. Frames
               are ordered by the integer between the first underscore and the
               first dot of the file name, e.g. ``wdm_07.png`` -> 7.
    gif_name:  path of the GIF file to write.
    duration:  display time per frame, in seconds.

    Raises IndexError if ``image_dir`` contains no PNG files.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import list.
    from contextlib import ExitStack

    def frame_number(path):
        return int(os.path.basename(path).split("_")[1].split(".")[0])

    files = sorted(
        glob.glob(os.path.join(image_dir, "*.png")), key=frame_number
    )
    if not files:
        # Same exception type the bare images[0] lookup used to raise, but with
        # a message that says what is actually wrong.
        raise IndexError(f"no PNG frames found in {image_dir!r}")

    # Image.open keeps the underlying file open; the ExitStack closes every
    # frame once the GIF has been written (or if saving fails).
    with ExitStack() as stack:
        images = [stack.enter_context(Image.open(name)) for name in files]
        images[0].save(
            gif_name,
            save_all=True,
            append_images=images[1:],
            duration=duration * 1000,  # Convert seconds to milliseconds
            loop=0,
        )


# Assemble every PNG in plt_dir into wdm.gif, with duration=0.5 seconds per frame.
make_gif_from_images(plt_dir, "wdm.gif", duration=0.5)