Source code for jesterTOV.utils

"""
Utilities
"""

from jax import vmap
import jax.numpy as jnp
from functools import partial
from jaxtyping import Array, Float
from interpax import interp1d as interpax_interp1d

from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt, PIDController

#################################
### CONSTANTS AND CONVERSIONS ###
#################################

# Fundamental constants. ``eV`` is the electron-volt expressed in joules (see
# ``MeV_to_J`` below); ``c``, ``G`` and ``Msun`` are combined in SI units to
# build the geometric-unit conversions further down.
eV = 1.602176634e-19
c = 299792458.0
G = 6.6743e-11
Msun = 1.988409870698051e30
hbarc = 197.3269804593025  # in MeV fm
# NOTE: ``hbar`` is an alias of ``hbarc`` (MeV fm), kept for compatibility with
# conventions from Rahul's code -- it is not hbar in isolation.
hbar = hbarc
m_p = 938.2720881604904  # in MeV
m_n = 939.5654205203889  # in MeV
m = (m_p + m_n) / 2.0  # in MeV, average nucleonic mass defined by Margueron et al
m_e = 0.510998  # mass electron in MeV
solar_mass_in_meter = Msun * G / c / c  # solar mass in geometric unit (G M / c^2)

# simple conversions
fm_to_m = 1e-15
MeV_to_J = 1e6 * eV
m_to_fm = 1.0 / fm_to_m
J_to_MeV = 1.0 / MeV_to_J

# number density
fm_inv3_to_SI = 1.0 / fm_to_m**3
# Number densities carry the same numerical value in SI and geometric units,
# so this factor is exactly 1.
number_density_to_geometric = 1
fm_inv3_to_geometric = fm_inv3_to_SI * number_density_to_geometric

SI_to_fm_inv3 = 1.0 / fm_inv3_to_SI
geometric_to_fm_inv3 = 1.0 / fm_inv3_to_geometric

# pressure and energy density
MeV_fm_inv3_to_SI = MeV_to_J * fm_inv3_to_SI
SI_to_MeV_fm_inv3 = 1.0 / MeV_fm_inv3_to_SI
pressure_SI_to_geometric = G / c**4
MeV_fm_inv3_to_geometric = MeV_fm_inv3_to_SI * pressure_SI_to_geometric
# 1 dyn/cm^2 = 0.1 Pa = 0.1 J/m^3, then J -> MeV and m^-3 -> fm^-3.
dyn_cm2_to_MeV_fm_inv3 = 1e-1 * J_to_MeV / m_to_fm**3
# 1 g/cm^3 = 1e3 kg/m^3; multiplying by c^2 turns mass density into energy density.
g_cm_inv3_to_MeV_fm_inv3 = 1e3 * c**2 * J_to_MeV / m_to_fm**3

# Inverse conversions. (SI_to_MeV_fm_inv3 is defined once, above.)
geometric_to_SI = 1.0 / pressure_SI_to_geometric
geometric_to_MeV_fm_inv3 = 1.0 / MeV_fm_inv3_to_geometric


#########################
### UTILITY FUNCTIONS ###
#########################

# Batched polynomial root finder: applies ``jnp.roots`` independently to each
# slice along the leading axis of its input (i.e. each row of a 2-D array of
# polynomial coefficients) and stacks the results along the leading axis.
# ``strip_zeros=False`` is passed because, per the JAX documentation, stripping
# leading zeros gives a data-dependent output shape that cannot be traced under
# ``vmap``/``jit``.
roots_vmap = vmap(
    lambda coeffs: jnp.roots(coeffs, strip_zeros=False),
    in_axes=0,
    out_axes=0,
)


@vmap
def cubic_root_for_proton_fraction(coefficients):
    """
    Return the three roots of ``a*x**3 + b*x**2 + c*x + d = 0`` via Cardano's formula.

    Because the function is wrapped in ``jax.vmap``, it is called with a 2-D
    array whose rows are ``(a, b, c, d)``. It returns an array of shape
    ``(batch, 3)``. Column 0 is ``S + U - b/(3a)``. Columns 1 and 2 are
    ``-(S + U)/2 - b/(3a) +/- i*sqrt(3)/2*(S - U)``. The imaginary unit in
    those two expressions makes the result dtype complex.

    NOTE(review): ``h`` is passed through ``jnp.sqrt``. For real coefficients
    with ``h < 0`` (three distinct real roots) this produces NaN. Presumably
    callers only hit the single-real-root case -- confirm.
    """
    a, b, c_coef, d = coefficients

    # Standard Cardano quantities for the depressed cubic.
    f = ((3.0 * c_coef / a) - ((b**2) / (a**2))) / 3.0
    g = (((2.0 * (b**3)) / (a**3)) - ((9.0 * b * c_coef) / (a**2)) + (27.0 * d / a)) / 27.0
    h = g**2 / 4.0 + f**3 / 27.0

    half_g = g / 2.0
    sqrt_h = jnp.sqrt(h)
    # jnp.cbrt returns the real cube root, including for negative arguments.
    S = jnp.cbrt(-half_g + sqrt_h)
    U = jnp.cbrt(-half_g - sqrt_h)

    shift = b / (3.0 * a)  # undoes the x -> t - b/(3a) substitution
    s_plus_u = S + U
    rotation = (S - U) * jnp.sqrt(3.0) * 0.5j

    root_0 = s_plus_u - shift
    root_1 = -s_plus_u / 2 - shift + rotation
    root_2 = -s_plus_u / 2 - shift - rotation
    return jnp.array([root_0, root_1, root_2])
def cumtrapz(y, x):
    """
    Cumulatively integrate y(x) with the composite trapezoidal rule.

    Parameters
    ----------
    y : jax.numpy.ndarray
        One-dimensional array of integrand values.
    x : jax.numpy.ndarray
        One-dimensional array of sample points, with the same shape as ``y``.

    Returns
    -------
    res : jax.numpy.ndarray
        Array with the same length as ``y``. For ``k >= 1``, ``res[k]`` is the
        trapezoidal integral of ``y`` from ``x[0]`` to ``x[k]``. ``res[0]`` is
        the small positive placeholder ``1e-30`` rather than exactly zero. The
        placeholder is not added to the later entries.
        NOTE(review): the placeholder presumably avoids log(0) or division by
        zero downstream -- confirm with callers.

    Raises
    ------
    AssertionError
        If ``x`` and ``y`` differ in shape or are not one-dimensional.
    """
    assert y.shape == x.shape, "Not matching shape between y and x"
    assert y.ndim == 1, "y is expected to be one-dimensional array"
    assert x.ndim == 1, "x is expected to be one-dimensional array"

    # Area of each trapezoid between consecutive samples, then running sum.
    panel_areas = jnp.diff(x) * (y[1:] + y[:-1]) / 2.0
    running_total = jnp.cumsum(panel_areas)

    leading_value = jnp.array([1e-30])
    return jnp.concatenate([leading_value, running_total])
def interp_in_logspace(x, xs, ys):
    """
    Interpolate ys(xs) linearly in log-log space and evaluate at x.

    Both the abscissae and the ordinates are log-transformed before calling
    ``jnp.interp``, and the result is exponentiated back. A pure power law
    ``y = A * x**k`` is therefore reproduced exactly between nodes. All of
    ``x``, ``xs`` and ``ys`` must be positive for the logarithms to be finite.
    Outside ``[xs[0], xs[-1]]`` the edge value is returned (``jnp.interp``
    clamps).
    """
    log_y_at_x = jnp.interp(jnp.log(x), jnp.log(xs), jnp.log(ys))
    return jnp.exp(log_y_at_x)
def limit_by_MTOV(
    pc: Array, m: Array, r: Array, l: Array
) -> tuple[Array, Array, Array, Array]:
    """
    Replace the part of a TOV sequence beyond the maximum mass, keeping static shapes.

    Let ``idx_TOV = argmax(m)``. Point ``i`` is kept when ``i == idx_TOV``, or
    when ``i < idx_TOV`` and ``m[i + 1] > m[i]``. Every other point, including
    all points after the maximum, is overwritten with the values at
    ``idx_TOV`` (the maximum-mass configuration). The outputs therefore keep
    the input length, which makes the function jit-friendly. The caller is
    expected to drop the resulting duplicates, e.g. with ``jnp.unique`` before
    building an interpolator. Finally, all four arrays are reordered by
    increasing mass.

    Args:
        pc (Array["npoints"]): Original pressure curve.
        m (Array["npoints"]): Original mass curve.
        r (Array["npoints"]): Original radius curve.
        l (Array["npoints"]): Original lambdas curve.

    Returns:
        tuple[Array["npoints"], Array["npoints"], Array["npoints"], Array["npoints"]]:
            ``(pc, m, r, l)`` after the replacement and sorting by mass.
    """
    idx_TOV = jnp.argmax(m)
    n_points = len(m)

    # diff gives n-1 flags "m[i + 1] > m[i]". Inserting True at the TOV index
    # restores length n and always keeps the maximum itself.
    rising = jnp.insert(jnp.diff(m) > 0, idx_TOV, True)
    keep = jnp.logical_and(rising, jnp.arange(n_points) <= idx_TOV)

    m_at_TOV = jnp.max(m)
    pc_new = jnp.where(keep, pc, pc[idx_TOV])
    m_new = jnp.where(keep, m, m_at_TOV)
    r_new = jnp.where(keep, r, r[idx_TOV])
    l_new = jnp.where(keep, l, l[idx_TOV])

    # Order by increasing mass (useful for plotting and interpolation).
    order = jnp.argsort(m_new)
    return pc_new[order], m_new[order], r_new[order], l_new[order]
################### ### SPLINES etc ### ###################
def cubic_spline(xq: Float[Array, "n"], xp: Float[Array, "n"], fp: Float[Array, "n"]):
    """
    Evaluate a cubic interpolant through (xp, fp) at the query points xq.

    Thin wrapper around interpax's ``interp1d``
    (https://github.com/f0uriest/interpax) with ``method="cubic"``. Note that
    this returns interpolated values, not a callable. See the interpax
    documentation for how its ``"cubic"`` method differs from the other spline
    variants it offers, and for its default behaviour outside the data range.

    Args:
        xq (Float[Array, "n"]): x values at which the interpolant is evaluated.
        xp (Float[Array, "n"]): x values of the data points.
        fp (Float[Array, "n"]): y values of the data points, i.e. fp = f(xp).

    Returns:
        The interpolated values at ``xq``, as returned by interpax's ``interp1d``.
    """
    spline_method = "cubic"
    return interpax_interp1d(xq, xp, fp, method=spline_method)
def sigmoid(x: Array) -> Array:
    """
    Elementwise logistic function ``1 / (1 + exp(-x))``.

    For very negative ``x``, ``exp(-x)`` overflows to ``inf`` and the result
    correctly evaluates to 0 rather than NaN.
    """
    denominator = 1.0 + jnp.exp(-x)
    return 1.0 / denominator
def calculate_rest_mass_density(
    e: Float[Array, "n"],
    p: Float[Array, "n"],
    *,
    dt0: float = 1e-8,
    rtol: float = 1e-5,
    atol: float = 1e-6,
):
    """
    Compute the rest-mass density along an EOS from energy density and pressure.

    Integrates the ODE

        d(rho)/d(e) = rho / (p(e) + e)

    from ``e[0]`` to ``e[-1]``. The solver is diffrax's Tsit5 with an adaptive
    PID step-size controller. ``p(e)`` is a piecewise-linear interpolation of
    the tabulated ``(e, p)`` pairs. The initial condition is
    ``rho(e[0]) = e[0]``, i.e. the rest-mass density is taken equal to the
    energy density at the first tabulated point.

    Parameters:
    - e (jax.numpy array): Energy density values. ``e`` is both the
      integration variable and the set of save points, so it is assumed to be
      increasing (TODO confirm callers guarantee this).
    - p (jax.numpy array): Pressure values at the same points as ``e``. Must
      be in the same units as ``e``, since the ODE adds them directly.
    - dt0 (float, keyword-only): Initial step size for the solver.
      Default 1e-8, the previously hard-coded value.
    - rtol (float, keyword-only): Relative tolerance of the PID step-size
      controller. Default 1e-5.
    - atol (float, keyword-only): Absolute tolerance of the PID step-size
      controller. Default 1e-6.
      NOTE(review): atol is absolute. If the densities are much smaller than
      this in the caller's units, pass a tighter value.

    Returns:
    - rho (jax.numpy array): Rest-mass density at each entry of ``e``.
    """

    def p_interp(e_val):
        # Piecewise-linear p(e) through the tabulated points.
        return jnp.interp(e_val, e, p)

    def rhs(t, rho, args):
        # diffrax vector field: t plays the role of the energy density.
        p_val = p_interp(t)
        return rho / (p_val + t)

    # Initial condition: rho[0] = e[0]
    rho0 = e[0]

    solution = diffeqsolve(
        ODETerm(rhs),
        Tsit5(),
        t0=e[0],  # Initial value of e
        t1=e[-1],  # Final value of e
        dt0=dt0,  # Initial step size
        y0=rho0,  # Initial value of rho
        saveat=SaveAt(ts=e),
        stepsize_controller=PIDController(rtol=rtol, atol=atol),
    )

    # rho evaluated at every requested energy density
    return solution.ys