LDC Spritz

Contents

LDC Spritz#

Lets try to make a PSD for the LDC Spritz

import numpy as np
import jax

jax.config.update("jax_enable_x64", True)
np.random.seed(0)

f = "/Users/avaj0001/Downloads/drive-download-20250430T050734Z-1-001/LDC-2b-Spritz-MBHB1.asc"
data = np.loadtxt(f, skiprows=2).T
t, X, Y, Z = data


X = X[:535552]
time = t[:535552]
gap_mask = X == 0
gap_mask = gap_mask[:535552]


# standardise X
SCALING = np.std(X)
X = (X - np.mean(X)) / SCALING
import os


from pywavelet import set_backend
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import rfft, rfftfreq
from scipy.signal.windows import tukey
from copy import deepcopy

import glob
from PIL import Image

set_backend("jax")

from pywavelet.transforms import from_freq_to_wavelet
from pywavelet.transforms import from_wavelet_to_freq
from pywavelet.types import Wavelet, FrequencySeries, TimeSeries
import jax.numpy as jnp

plot_dir = "out_spritz"
os.makedirs(plot_dir, exist_ok=True)


def make_gif_from_images(image_regex, gif_name, duration=3):
    images = []
    files = glob.glob(image_regex)
    files = sorted(
        files,
        key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]),
    )
    for filename in files:
        images.append(Image.open(filename))
    images[0].save(
        gif_name,
        save_all=True,
        append_images=images[1:],
        duration=duration * 1000,  # Convert seconds to milliseconds
        loop=0,
    )
[12:06:52] WARNING  JAX SUBPACKAGE NOT FULLY TESTED                                                   __init__.py:5
           INFO     Jax running on cpu [32bit precision].                                            __init__.py:21
           WARNING  Jax is not running in 64bit precision. To change, use                            __init__.py:23
                    jax.config.update('jax_enable_x64', True).                                                     
def plot_wnm_and_ts(
    timeseries,
    freqseries,
    wnm,
    label="spritz",
    psd_approx=None,
    ax=None,
    save=True,
):
    if ax is None:
        fig, ax = plt.subplots(3, 1, figsize=(10, 6))

    wnm.plot(zscale="log", absolute=True, ax=ax[0])

    freqseries.plot_periodogram(ax=ax[1])
    if psd_approx is not None:
        psd_approx.plot(ax[1], scaling=1e5)

    ax[1].legend(frameon=False)

    timeseries.plot(ax=ax[2])

    fig = ax[1].get_figure()
    fig.suptitle(label)
    plt.tight_layout()
    if save:
        plt.savefig(os.path.join(plot_dir, f"{label}.png"))


def plot_mask(wnm, mask, label="mask"):
    mask_wnm = deepcopy(wnm)
    mask_wnm.data = mask_wnm.data.at[mask].set(1)  # set outliers to 1 (masked)
    mask_wnm.data = mask_wnm.data.at[~mask].set(
        0
    )  # set non-outliers to 0 (unmasked -- we keep)
    d = wnm.data.ravel()
    percentage = np.sum(mask) / len(d)
    print(f"Percentage of data masked: {percentage:.2%}")

    mask_wnm.plot(
        absolute=True,
        cmap="binary",
        show_colorbar=False,
        label=f"{percentage:.2%}% masked",
        zscale="linear",
    )
    plt.tight_layout()
    plt.savefig(os.path.join(plot_dir, f"{label}.png"))


def mad_threshold(wnm: Wavelet, nsigma: float = 5.0) -> Wavelet:
    """
    Sigma-clipping using median and MAD across time for each frequency bin.
    """
    print(f"MAD thresholding -- deleting values > nsigma ({nsigma:.2f})* MAD")

    wnm_copy = deepcopy(wnm)
    data = jnp.abs(wnm.data)
    # median and MAD
    median = jnp.median(data)
    mad = jnp.median(np.abs(data - median))

    # mask outliers (True = outlier)
    mask = np.abs(data - median) > nsigma * mad
    wnm_copy.data = wnm_copy.data.at[mask].set(median)
    return wnm_copy, mask


def generate_noise_from_psd(noise_psd, freqs, n_samples):
    white_noise_fft = (
        np.random.normal(size=len(freqs))
        + 1j * np.random.normal(size=len(freqs))
    ) / np.sqrt(2)
    white_noise_fft[0] /= np.sqrt(2)  # Adjust DC component
    # Scale amplitude by sqrt(n_samples) to match the PSD normalization with np.fft.irfft
    colored_noise_fft = white_noise_fft * np.sqrt(noise_psd * n_samples)
    noise_time = np.fft.irfft(colored_noise_fft, n=n_samples)
    return noise_time


def denoise(wnm, thr, gap_mask, i=0):

    wnm, mask = mad_threshold(wnm, nsigma=thr * 10)
    plot_mask(wnm, mask, label=f"mask-{i}")

    # plot each iteration
    fs = wnm.to_frequencyseries()
    ts = fs.to_timeseries()

    psd_approx = PSDApprox.fit(ts)

    fig, ax = plt.subplots(5, 1, figsize=(10, 10))
    plot_wnm_and_ts(
        ts,
        fs,
        wnm,
        label=f"Iter_{i + 1:02d}",
        psd_approx=psd_approx,
        ax=ax,
        save=False,
    )

    print(f"ts ND : {ts.ND}, wnm ND: {wnm.ND}")
    print("N Gaps:", np.sum(gap_mask))
    print("Filling gaps with generated noise")

    # fill in gap
    generated_noise = generate_noise_from_psd(
        psd_approx.power * 1e2, psd_approx.freq, ts.ND
    )
    generated_noise = generated_noise.reshape(-1, 1).T[0]
    generated_noise = generated_noise[: len(ts.data)]

    # fill in the gaps
    time_data = np.array(ts.data)

    print("BEFORE")
    print(np.std(generated_noise))
    # scale the generated_noise to match the original ts.data
    # generated_noise = generated_noise * np.std(ts.data[0:100]) + np.mean(ts.data[0:100])
    # print(np.std(generated_noise))
    time_data[gap_mask] = generated_noise[gap_mask] * 1e2

    new_ts = TimeSeries(time_data, time=ts.time)
    new_wnm = new_ts.to_frequencyseries().to_wavelet(Nf=256)

    ts.plot(ax[3])
    ax[3].plot(
        ts.time[gap_mask],
        generated_noise[gap_mask],
        color="tab:orange",
        label="generated noise",
    )
    # new_ts.plot(ax[3], color='tab:orange', alpha=0.5)
    new_wnm.plot(zscale="log", absolute=True, ax=ax[4])
    plt.savefig(
        os.path.join(
            plot_dir,
            f"Iter_{i + 1:02d}",
        )
    )

    return wnm


def iterative_denoise(
    wnm: Wavelet,
    gap_mask,
    iterations: int = 5,
    base_threshold: float = 0.3,
    decay: float = 0.9,
) -> Wavelet:
    """
    Iteratively apply thresholding, decaying the threshold each time.
    method: 'global', 'mad', or 'blockwise'
    """
    new_wnm = deepcopy(wnm)
    for i in range(iterations):
        thr = base_threshold * (decay**i)
        new_wnm = denoise(wnm, thr, gap_mask, i)

    return new_wnm


ts = TimeSeries(X[:535552], time=t[:535552])
print(f"ts ND : {ts.ND}")


fs = ts.to_frequencyseries()
orig_wnm = fs.to_wavelet(Nf=256)
plot_wnm_and_ts(ts, fs, orig_wnm, label="Iter_00")

wnm = fs.to_wavelet(Nf=256)
print(f"wnm ND : {wnm.ND}")
# psd_approx = PSDApprox.fit(ts)
#
# generated_noise = generate_noise_from_psd(psd_approx.power, psd_approx.freq, ts.ND)


wnm = iterative_denoise(
    wnm, gap_mask, iterations=1, base_threshold=0.9, decay=0.9
)
# make_gif_from_images(os.path.join(plot_dir, 'Iter_*.png'), os.path.join(plot_dir, 'wnm_psd_gen.gif'), duration=0.5)
ts ND : 535552
wnm ND : 535552
MAD thresholding -- deleting values > nsigma (9.00)* MAD
Percentage of data masked: 3.17%
ts ND : 535552, wnm ND: 535552
N Gaps: 10080
Filling gaps with generated noise
BEFORE
2.8030794625010573
../../_images/3b13fdfb0c4a40ea36f655ecb22dac236faaa18a837cfd37279b79e97341ec9d.png ../../_images/e944ecac8329e09d1b0f3d6f210521bf44eeea8e91495f35cf53d108db505b21.png ../../_images/a3f04f1d13426db3d111a164a6da002d83bf34292fef8d62fd83a7295af2f618.png
# gap_mask = ts.data == 0
# ts.m
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[43], line 2
      1 gap_mask = ts.data == 0
----> 2 ts.m

AttributeError: 'TimeSeries' object has no attribute 'm'
# i = 0
# wnm, mask = mad_threshold(wnm, nsigma=10)
#
# # plot each iteration
# fs = wnm.to_frequencyseries()
# ts = fs.to_timeseries()
#
# psd_approx = PSDApprox.fit(ts)
#


fig, ax = plt.subplots(5, 1, figsize=(10, 10))
plot_wnm_and_ts(
    ts,
    fs,
    wnm,
    label=f"Iter_{i + 1:02d}",
    psd_approx=psd_approx,
    ax=ax,
    save=False,
)


# fill in gap
generated_noise = generate_noise_from_psd(
    psd_approx.power * 1e1, psd_approx.freq, ts.ND
)
generated_noise = generated_noise.reshape(-1, 1).T[0]
generated_noise = generated_noise[: len(ts.data)]


# fill in the gaps
time_data = np.array(ts.data)

print("BEFORE")
print(np.std(generated_noise))
# scale the generated_noise to match the original ts.data
# generated_noise = generated_noise * np.std(ts.data[0:100]) + np.mean(ts.data[0:100])
# print(np.std(generated_noise))
time_data[gap_mask] = generated_noise[gap_mask] * 1e1


new_ts = TimeSeries(time_data, time=ts.time)
new_wnm = new_ts.to_frequencyseries().to_wavelet(Nf=256)

ts.plot(ax[3])
ax[3].plot(
    ts.time[gap_mask],
    generated_noise[gap_mask],
    color="tab:orange",
    label="generated noise",
)
# new_ts.plot(ax[3], color='tab:orange', alpha=0.5)
new_wnm.plot(zscale="log", absolute=True, ax=ax[4])
plt.savefig(
    os.path.join(
        plot_dir,
        f"Iter_{i + 1:02d}",
    )
)
BEFORE
0.8844361892657235
../../_images/45a60273bec45299b43090f8e27dbc6022ff0b05fd20fd993516f27bcbfdeaa7.png
ts.time[gap_mask]
array([], dtype=float64)
np.std(ts.data)
np.float64(1.2227757197604746e-20)
generated_noise = generate_noise_from_psd(
    psd_approx.power * 1e5, psd_approx.freq, 2048
)
plt.plot(generated_noise * 1e-18)
[<matplotlib.lines.Line2D at 0x1225e47c0>]
../../_images/15ecc186cd09d59028671a3ad3c16b3e65e8c468ac922afd650c446795d70c55.png
plt.plot(ts.time, np.array(generated_noise * 1e-20))
# plt.scatter(ts.time[gap_mask], generated_noise[gap_mask], color="tab:orange", zorder=10)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[54], line 1
----> 1 plt.plot(ts.time, np.array(generated_noise * 1e-20))
      2 plt.scatter(ts.time[gap_mask], generated_noise[gap_mask], color="tab:orange", zorder=10)

File ~/miniforge3/envs/pywavelet/lib/python3.10/site-packages/matplotlib/pyplot.py:3827, in plot(scalex, scaley, data, *args, **kwargs)
   3819 @_copy_docstring_and_deprecators(Axes.plot)
   3820 def plot(
   3821     *args: float | ArrayLike | str,
   (...)
   3825     **kwargs,
   3826 ) -> list[Line2D]:
-> 3827     return gca().plot(
   3828         *args,
   3829         scalex=scalex,
   3830         scaley=scaley,
   3831         **({"data": data} if data is not None else {}),
   3832         **kwargs,
   3833     )

File ~/miniforge3/envs/pywavelet/lib/python3.10/site-packages/matplotlib/axes/_axes.py:1777, in Axes.plot(self, scalex, scaley, data, *args, **kwargs)
   1534 """
   1535 Plot y versus x as lines and/or markers.
   1536 
   (...)
   1774 (``'green'``) or hex strings (``'#008000'``).
   1775 """
   1776 kwargs = cbook.normalize_kwargs(kwargs, mlines.Line2D)
-> 1777 lines = [*self._get_lines(self, *args, data=data, **kwargs)]
   1778 for line in lines:
   1779     self.add_line(line)

File ~/miniforge3/envs/pywavelet/lib/python3.10/site-packages/matplotlib/axes/_base.py:297, in _process_plot_var_args.__call__(self, axes, data, return_kwargs, *args, **kwargs)
    295     this += args[0],
    296     args = args[1:]
--> 297 yield from self._plot_args(
    298     axes, this, kwargs, ambiguous_fmt_datakey=ambiguous_fmt_datakey,
    299     return_kwargs=return_kwargs
    300 )

File ~/miniforge3/envs/pywavelet/lib/python3.10/site-packages/matplotlib/axes/_base.py:494, in _process_plot_var_args._plot_args(self, axes, tup, kwargs, return_kwargs, ambiguous_fmt_datakey)
    491     axes.yaxis.update_units(y)
    493 if x.shape[0] != y.shape[0]:
--> 494     raise ValueError(f"x and y must have same first dimension, but "
    495                      f"have shapes {x.shape} and {y.shape}")
    496 if x.ndim > 2 or y.ndim > 2:
    497     raise ValueError(f"x and y can be no greater than 2D, but have "
    498                      f"shapes {x.shape} and {y.shape}")

ValueError: x and y must have same first dimension, but have shapes (535552,) and (2048,)
../../_images/33591efa9564a66f129f6ec32a0f5b5c19f44e60783ea04c1d8c7ee72fba24dd.png

Fitting PSD#

import jax
import jax.numpy as jnp
import optax
from typing import Tuple


def fit_psd_spline_jax(
    freqs: jnp.ndarray,  # shape [N]
    psd_obs: jnp.ndarray,  # shape [N]
    n_knots: int = 20,
    lr: float = 1e-3,
    steps: int = 200,
    log_eps: float = 1e-12,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    knots = jnp.linspace(freqs[0], freqs[-1], n_knots)
    log_true = jnp.log(psd_obs + log_eps)
    init_knot_logs = jnp.interp(knots, freqs, log_true)

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(init_knot_logs)

    def loss_and_grads(knot_logs):
        pred = jnp.interp(freqs, knots, knot_logs)
        loss = jnp.mean((pred - log_true) ** 2)
        grads = jax.grad(
            lambda kl: jnp.mean((jnp.interp(freqs, knots, kl) - log_true) ** 2)
        )(knot_logs)
        return loss, grads

    def step(i, state):
        knot_logs, opt_state = state
        loss, grads = loss_and_grads(knot_logs)
        updates, opt_state = optimizer.update(grads, opt_state)
        knot_logs = optax.apply_updates(knot_logs, updates)
        return (knot_logs, opt_state)

    # training loop
    knot_logs_final, _ = jax.jit(
        lambda kl, os: jax.lax.fori_loop(0, steps, step, (kl, os))
    )(init_knot_logs, opt_state)

    log_psd_fit = jnp.interp(freqs, knots, knot_logs_final)
    return jnp.exp(log_psd_fit), knot_logs_final, knots


std = np.std(ts.data)
mean = np.mean(ts.data)
ynorm = (ts.data - mean) / std
sampling_freq = float(1 / (ts.time[1] - ts.time[0]))
freq = np.fft.rfftfreq(len(ynorm), d=1 / sampling_freq)
power = np.abs(jnp.fft.rfft(ynorm)) ** 2 / len(ynorm)

freq = freq[1::5]
power = power[1::5]

psd_fit, fitted_logs, knots = fit_psd_spline_jax(freq, power)

# then you can plot
import matplotlib.pyplot as plt

plt.loglog(freq, power, label="data")
plt.loglog(freq, psd_fit, label="spline fit")
plt.scatter(knots, jnp.exp(fitted_logs), color="C1", label="knots")
plt.legend()
plt.show()
../../_images/498a06b13f9a6b0837604a84b0ec513a91d7a7ea1b68303980b62e0b64491460.png
import jax
import jax.numpy as jnp
from jax import lax

import jax
import jax.numpy as jnp
from jax import lax


from dataclasses import dataclass


@dataclass
class PSDApprox:
    freq: jnp.array
    power: jnp.array

    @classmethod
    def fit(cls, ts, window_size=101):
        """
        Calculates a running median approximation of the Power Spectral Density (PSD)
        from a periodogram.

        Args:
          periodogram: A JAX array representing the periodogram.
          window_size: An integer representing the size of the running median window.
                       Must be odd.

        Returns:
          A JAX array representing the running median PSD approximation.
        """

        std = np.std(ts.data)
        mean = np.mean(ts.data)
        ynorm = (ts.data - mean) / std
        sampling_freq = float(1 / (ts.time[1] - ts.time[0]))
        freq = np.fft.rfftfreq(len(ynorm), d=1 / sampling_freq)
        power = np.abs(jnp.fft.rfft(ynorm)) ** 2 / len(ynorm)

        freq = freq[1::5]
        periodogram = power[1::5]
        periodogram = jnp.array(periodogram, dtype=jnp.float64)

        if window_size % 2 == 0:
            raise ValueError("Window size must be odd for running median.")

        padding = window_size // 2
        padded_periodogram = jnp.pad(
            periodogram, (padding, padding), mode="reflect"
        )

        def median_window(i):
            window = jax.lax.dynamic_slice(
                padded_periodogram, (i,), (window_size,)
            )
            return jnp.median(jnp.sort(window))

        n = periodogram.shape[0]
        indices = jnp.arange(n)
        running_median = jax.vmap(median_window)(indices)

        # rescale the running median to have the same power as the original
        running_median = running_median

        return cls(freq=freq, power=running_median * std**2)

    def plot(self, ax, scaling=1):
        p = np.array(self.power) * scaling
        ax.loglog(self.freq, p, label="running median approx", linestyle="--")


psd_approx = PSDApprox.fit(ts)

fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.loglog(freq, power, label="Data")
psd_approx.plot(ax)
psd_approx.plot(ax, scaling=1e3)
../../_images/3dec954e2f1e4ff6e010e717acdae61b412915042cb2b3246223d6e3e3e92f22.png
psd_approx = running_median_psd_approximation(fs.periodogram, window_size=101)


plt.loglog(freq, power, label="spline fit")
plt.plot(freq, psd_approx, label="running median approx", linestyle="--")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[8], line 5
      1 psd_approx = running_median_psd_approximation(fs.periodogram, window_size=101)
      4 plt.loglog(freq, power, label="spline fit")
----> 5 plt.plot(freq, psd_approx, label="running median approx", linestyle='--')

File ~/miniforge3/envs/pywavelet/lib/python3.10/site-packages/matplotlib/pyplot.py:3827, in plot(scalex, scaley, data, *args, **kwargs)
   3819 @_copy_docstring_and_deprecators(Axes.plot)
   3820 def plot(
   3821     *args: float | ArrayLike | str,
   (...)
   3825     **kwargs,
   3826 ) -> list[Line2D]:
-> 3827     return gca().plot(
   3828         *args,
   3829         scalex=scalex,
   3830         scaley=scaley,
   3831         **({"data": data} if data is not None else {}),
   3832         **kwargs,
   3833     )

File ~/miniforge3/envs/pywavelet/lib/python3.10/site-packages/matplotlib/axes/_axes.py:1777, in Axes.plot(self, scalex, scaley, data, *args, **kwargs)
   1534 """
   1535 Plot y versus x as lines and/or markers.
   1536 
   (...)
   1774 (``'green'``) or hex strings (``'#008000'``).
   1775 """
   1776 kwargs = cbook.normalize_kwargs(kwargs, mlines.Line2D)
-> 1777 lines = [*self._get_lines(self, *args, data=data, **kwargs)]
   1778 for line in lines:
   1779     self.add_line(line)

File ~/miniforge3/envs/pywavelet/lib/python3.10/site-packages/matplotlib/axes/_base.py:297, in _process_plot_var_args.__call__(self, axes, data, return_kwargs, *args, **kwargs)
    295     this += args[0],
    296     args = args[1:]
--> 297 yield from self._plot_args(
    298     axes, this, kwargs, ambiguous_fmt_datakey=ambiguous_fmt_datakey,
    299     return_kwargs=return_kwargs
    300 )

File ~/miniforge3/envs/pywavelet/lib/python3.10/site-packages/matplotlib/axes/_base.py:494, in _process_plot_var_args._plot_args(self, axes, tup, kwargs, return_kwargs, ambiguous_fmt_datakey)
    491     axes.yaxis.update_units(y)
    493 if x.shape[0] != y.shape[0]:
--> 494     raise ValueError(f"x and y must have same first dimension, but "
    495                      f"have shapes {x.shape} and {y.shape}")
    496 if x.ndim > 2 or y.ndim > 2:
    497     raise ValueError(f"x and y can be no greater than 2D, but have "
    498                      f"shapes {x.shape} and {y.shape}")

ValueError: x and y must have same first dimension, but have shapes (53568,) and (267841,)
../../_images/2c5538679dccace4d1d0bb72075ede8fc7fd34aa8bb3d7d0e8c402c88902bfbe.png
psd_approx = PSDApprox.fit(ts)

fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.loglog(freq, power, label="Data")
psd_approx.plot(ax)
../../_images/198f1e28ca9e90af4ca17ea8eb6b4e9af358f490738e149a53baaa72cbf9f32c.png
%reload_ext autoreload
%autoreload 2

from log_psplines.psplines import LogPSplines, Periodogram
from log_psplines.datasets import Timeseries as PsplinesTimeseries
from log_psplines.plotting import plot_pdgrm
from log_psplines.mcmc import run_mcmc


def get_basic_psd(ts):
    psd_approx = PSDApprox.fit(ts)

    std = np.std(ts.data)
    mean = np.mean(ts.data)
    ynorm = (ts.data - mean) / std
    new_ts = PsplinesTimeseries(t=ts.time, y=ynorm, std=std)
    periodogram = new_ts.to_periodogram()
    spline_model = LogPSplines.from_periodogram(
        periodogram=periodogram,
        n_knots=30,
        degree=3,
        diffMatrixOrder=2,
        parametric_model=psd_approx.power,
    )
    psd_fit = spline_model() * std**2
    freq = periodogram.freqs
    return freq, psd_fit


#
# freq, psd = get_basic_psd(ts)
# fig, ax = ts.to_frequencyseries().plot_periodogram()
# ax.plot(freq, psd, label='LogPSpline fit', color='red')

#
# # new_ts = new_ts.standardise()
# periodogram = new_ts.to_periodogram()
#
# plot_pdgrm(
#     pdgrm=periodogram
# )
# new_ts.fs

# print(new_ts.std)
#
# plt.plot(new_ts.t, new_ts.y)

psd_approx = PSDApprox.fit(ts)
pgrm = Periodogram(freqs=freq[1::5], power=power[1::5])
approx_pdgrm = Periodogram(
    freqs=freq[1::5], power=(psd_approx.power + 1e-6)[1::5]
)

# periodogram = new_ts.to_periodogram()
# periodogram = periodogram.downsample(factor=2)
spline_model = LogPSplines.from_periodogram(
    periodogram=approx_pdgrm,
    n_knots=40,
    degree=3,
    diffMatrixOrder=2,
    knot_kwargs=dict(frac_log=0.2),
    parametric_model=psd_approx.power[1::5],
)
# psd_fit = spline_model()*std**2
# freq = periodogram.freqs


#
# plot_pdgrm(approx_pdgrm,
#            spline_model=spline_model,)
# spline_model.weights = jnp.zeros_like(spline_model.weights)


fig, ax = plot_pdgrm(
    pgrm,
    spline_model=spline_model,
)
(<Figure size 400x300 with 1 Axes>,
 <Axes: xlabel='Frequency (Hz)', ylabel='Power'>)
../../_images/2c92e04ec0c3aab10de8df43b12832e352d1155577ac3004cfb81876c520e5a6.png
spline_model.weights
Array([-1.83805977e-04, -8.91185326e-04, -3.50013510e-03, -1.41638674e-02,
       -4.68858779e-02, -5.48335022e-02, -8.07325358e-02, -4.11451571e-02,
       -2.90809579e-02,  1.35925098e-02,  1.89796170e-03, -1.22993804e-01,
       -4.31864295e-02, -6.86834892e-02, -2.76254263e-02, -5.92031779e-02,
       -3.12930658e-02, -1.45095947e-01, -1.69640967e-01, -9.97554737e-03,
       -3.66073730e-02,  8.64408365e-04,  4.08797426e-03, -1.64652760e-02,
       -5.38484879e-03, -9.39892661e-02, -4.52676155e-02, -3.97503113e-01,
       -1.78027318e-01, -5.50896853e-02, -2.90669526e-02,  1.80494641e-02,
        2.39938545e-02,  3.91933908e-02,  3.06856455e-02,  2.16521871e-02,
       -3.09561810e-04, -1.92035073e-01,  6.27716268e-02, -1.60554802e-02,
        3.00683515e-02,  3.03128060e-03], dtype=float64)
import time

# approx_pdgrm = approx_pdgrm.downsample(10)

t0 = time.time()
samples, spline_model = run_mcmc(
    pgrm,
    parametric_model=psd_approx.power[1::5],
    n_knots=40,
    num_samples=500,
    num_warmup=500,
)


runtime = float(time.time()) - t0


fig, ax = plot_pdgrm(pgrm, spline_model, sampler.get_samples()["weights"])
fig.savefig(os.path.join(plot_dir, f"test_mcmc.png"))
Spline model: LogPSplines(knots=40, degree=3, n=10711)
sample: 100%|██████████| 1000/1000 [05:59<00:00,  2.78it/s, 1023 steps of size 5.86e-04. acc. prob=0.93]
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[87], line 18
      6 samples, spline_model = run_mcmc(
      7     pgrm, 
      8     parametric_model=psd_approx.power[1::5],
      9     n_knots=40, num_samples=500, num_warmup=500,
     10 )
     14 runtime = float(time.time()) - t0
---> 18 fig, ax = plot_pdgrm(pgrm, spline_model, sampler.get_samples()['weights'])
     19 fig.savefig(os.path.join(plot_dir, f"test_mcmc.png"))

NameError: name 'sampler' is not defined
fig, ax = plot_pdgrm(pgrm, spline_model, samples.get_samples()["weights"])
fig.savefig(os.path.join(plot_dir, f"test_mcmc.png"))
../../_images/8d5533ed10e224a327870710bf53a3faa1a28f3f69641ecebdc2e2ee6ceab6a0.png
fig, ax = plot_pdgrm(pgrm, spline_model, samples.get_samples()["weights"])
ax.set_xlim(left=10e-5)
# fig.savefig(os.path.join(plot_dir, f"test_mcmc.png"))
(0.0001, np.float64(0.0999925310707457))
../../_images/54037f264781632b97537aeebcef0c40d29c4a80352bf7e542e440e32478631b.png
new_ts = PsplinesTimeseries(t=ts.time, y=ts.data)
new_ts.standardise()
periodogram = new_ts.to_periodogram()
import jax
import jax.numpy as jnp
import optax

logP = jnp.log(power + 1e-6)  # log-PSD (add eps to avoid log(0))

# Define knots for piecewise-linear fit on [0,0.5]
num_knots = 40
knots = jnp.linspace(freq[0], freq[-1], num_knots)
# Initialize spline coefficients (values at knots)
coefs = jnp.zeros(num_knots)

# Hyperparameters: smoothing penalty weight
lambda_penalty = 1e-2


def loss_fn(coefs, freq, logP):
    # Linear interpolation of coefs at frequencies
    y_pred = jnp.interp(freq, knots, coefs)
    # Mean squared error on log-PSD
    mse = jnp.mean((y_pred - logP) ** 2)
    # Second-difference penalty
    pen = lambda_penalty * jnp.sum(jnp.diff(coefs, 2) ** 2)
    return mse + pen


# Use Optax to minimize the loss
optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(coefs)


@jax.jit
def train_step(coefs, opt_state, freq, logP):
    loss, grads = jax.value_and_grad(loss_fn)(coefs, freq, logP)
    updates, opt_state = optimizer.update(grads, opt_state)
    coefs = optax.apply_updates(coefs, updates)
    return coefs, opt_state, loss


# Run a few training iterations
for i in range(500):
    coefs, opt_state, loss = train_step(coefs, opt_state, freq, logP)
# The fitted smooth log-PSD is:
logP_smooth_spline = jnp.interp(freq, knots, coefs)
plt.loglog(freq, power)
plt.loglog(freq, jnp.exp(logP_smooth_spline), label="spline fit")
[<matplotlib.lines.Line2D at 0x12d302620>]
../../_images/b60946182971ed30bfd9daca454b54f7339d8f7eeac4b60b756164801dc5e93c.png
gap_mask
array([False, False, False, ..., False, False, False])