"""
Utilities
"""
from jax import vmap
import jax.numpy as jnp
from functools import partial
from jaxtyping import Array, Float
from interpax import interp1d as interpax_interp1d
from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt, PIDController
#################################
### CONSTANTS AND CONVERSIONS ###
#################################
eV = 1.602176634e-19  # electron volt in J
c = 299792458.0  # speed of light in m/s
G = 6.6743e-11  # gravitational constant in m^3 kg^-1 s^-2
Msun = 1.988409870698051e30  # solar mass in kg
hbarc = 197.3269804593025  # in MeV fm
hbar = hbarc  # conventions from Rahul's code: "hbar" here holds hbar*c in MeV fm
m_p = 938.2720881604904  # in MeV
m_n = 939.5654205203889  # in MeV
m = (m_p + m_n) / 2.0  # in MeV, average nucleonic mass defined by Margueron et al
m_e = 0.510998  # mass electron in MeV
solar_mass_in_meter = Msun * G / c / c  # solar mass in geometric units, G*Msun/c^2 in m

# simple conversions
fm_to_m = 1e-15
MeV_to_J = 1e6 * eV
m_to_fm = 1.0 / fm_to_m
J_to_MeV = 1.0 / MeV_to_J

# number density
fm_inv3_to_SI = 1.0 / fm_to_m**3  # fm^-3 -> m^-3
number_density_to_geometric = 1  # m^-3 is used unchanged as the geometric number-density unit
fm_inv3_to_geometric = fm_inv3_to_SI * number_density_to_geometric
SI_to_fm_inv3 = 1.0 / fm_inv3_to_SI
geometric_to_fm_inv3 = 1.0 / fm_inv3_to_geometric

# pressure and energy density
MeV_fm_inv3_to_SI = MeV_to_J * fm_inv3_to_SI  # MeV fm^-3 -> J m^-3 (= Pa)
SI_to_MeV_fm_inv3 = 1.0 / MeV_fm_inv3_to_SI
pressure_SI_to_geometric = G / c**4  # Pa -> m^-2
MeV_fm_inv3_to_geometric = MeV_fm_inv3_to_SI * pressure_SI_to_geometric
dyn_cm2_to_MeV_fm_inv3 = 1e-1 * J_to_MeV / m_to_fm**3  # 1 dyn cm^-2 = 0.1 J m^-3
g_cm_inv3_to_MeV_fm_inv3 = 1e3 * c**2 * J_to_MeV / m_to_fm**3  # 1 g cm^-3 = 1e3 kg m^-3, times c^2
geometric_to_SI = 1.0 / pressure_SI_to_geometric
geometric_to_MeV_fm_inv3 = 1.0 / MeV_fm_inv3_to_geometric
#########################
### UTILITY FUNCTIONS ###
#########################
def _roots_keep_leading_zeros(coefficients):
    """Roots of one polynomial, with strip_zeros=False so leading zeros are not removed."""
    return jnp.roots(coefficients, strip_zeros=False)


# Batched jnp.roots: maps over the leading axis of a stack of coefficient vectors.
roots_vmap = vmap(_roots_keep_leading_zeros)
[docs]
@vmap
def cubic_root_for_proton_fraction(coefficients):
    """
    Roots of a batch of cubics a*x^3 + b*x^2 + c*x + d via Cardano's formula.

    The function is vmapped over the leading axis, so each entry of the batch is a
    4-vector (a, b, c, d).

    Returns, per polynomial, jnp.array([x1, x2, x3]) where
    x1 = S + U - b/(3a) (the real root when h >= 0), and x2, x3 are the
    complex-conjugate pair -(S + U)/2 - b/(3a) +/- i*sqrt(3)/2*(S - U).
    For real-valued input with h < 0 (three distinct real roots), jnp.sqrt(h)
    is NaN, and so are the returned roots.
    """
    a, b, c_coef, d = coefficients
    shift = b / (3.0 * a)

    # Coefficients of the depressed cubic t^3 + f*t + g = 0 and its discriminant term h
    f = ((3.0 * c_coef / a) - ((b**2) / (a**2))) / 3.0
    g = (((2.0 * (b**3)) / (a**3)) - ((9.0 * b * c_coef) / (a**2)) + (27.0 * d / a)) / 27.0
    h = g**2 / 4.0 + f**3 / 27.0

    # Real cube roots of -g/2 +/- sqrt(h)
    sqrt_h = jnp.sqrt(h)
    half_g = g / 2.0
    S = jnp.cbrt(-half_g + sqrt_h)
    U = jnp.cbrt(-half_g - sqrt_h)

    sum_su = S + U
    real_part = -sum_su / 2 - shift
    imag_part = (S - U) * jnp.sqrt(3.0) * 0.5j

    x1 = sum_su - shift
    x2 = real_part + imag_part
    x3 = real_part - imag_part
    return jnp.array([x1, x2, x3])
[docs]
def cumtrapz(y, x, initial=1e-30):
    """
    Cumulatively integrate y(x) using the composite trapezoidal rule.

    Parameters
    ----------
    y : jax.numpy.ndarray
        One-dimensional values to integrate.
    x : jax.numpy.ndarray
        One-dimensional coordinate to integrate along; must have the same shape as `y`.
    initial : float, optional
        Value placed in front of the cumulative sums, i.e. the value reported at
        ``x[0]``. It is not added to the later entries. Defaults to ``1e-30``
        (the historical behaviour) rather than ``0``, presumably so the result
        stays strictly positive, e.g. for log-space interpolation (TODO confirm).
        Pass ``0.0`` for the textbook cumulative integral.

    Returns
    -------
    res : jax.numpy.ndarray
        For non-empty input, an array of the same length as `y` with
        ``res[0] = initial`` and, for ``i >= 1``,
        ``res[i] = sum_{k < i} (x[k+1] - x[k]) * (y[k+1] + y[k]) / 2``.

    Raises
    ------
    ValueError
        If `y` and `x` differ in shape or are not one-dimensional.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # Shapes are static, so these checks also run fine while tracing under jit.
    if y.shape != x.shape:
        raise ValueError("Not matching shape between y and x")
    if y.ndim != 1:
        raise ValueError("y is expected to be one-dimensional array")

    # Trapezoid area of each interval [x[k], x[k+1]], accumulated
    dx = jnp.diff(x)
    res = jnp.cumsum(dx * (y[1:] + y[:-1]) / 2.0)
    return jnp.concatenate((jnp.array([initial]), res))
[docs]
def interp_in_logspace(x, xs, ys):
    """
    Piecewise-linear interpolation of ys(xs) in log-log space, evaluated at x.

    log(ys) is interpolated against log(xs) with jnp.interp and the result is
    exponentiated, i.e. a power law between neighbouring points. x, xs and ys
    must be positive, since their logarithms are taken. xs must be increasing,
    as jnp.interp requires. Outside [xs[0], xs[-1]], jnp.interp clamps to the
    end values.
    """
    log_y_at_x = jnp.interp(jnp.log(x), jnp.log(xs), jnp.log(ys))
    return jnp.exp(log_y_at_x)
[docs]
def limit_by_MTOV(
    pc: Array, m: Array, r: Array, l: Array
) -> tuple[Array, Array, Array, Array]:
    """
    Limits the pc, M, R and Lambda curves to the part up to MTOV in a jit-friendly
    manner (i.e., static shape sizes).

    Points are not dropped, since that would change the array length. Instead, every
    point that is not kept is overwritten with the values at the maximum-mass (TOV)
    point. The idea is to feed the result into some routine that builds an
    interpolation out of it and uses jnp.unique to get rid of these duplicates.

    A point i is kept if i is the TOV index (argmax of m), or if i lies before the TOV
    index and m[i + 1] > m[i]. Every point after the TOV index is replaced.

    Args:
        pc (Array["npoints"]): Original pressure
        m (Array["npoints"]): Original mass curve
        r (Array["npoints"]): Original radius curve
        l (Array["npoints"]): Original lambdas curve

    Returns:
        tuple[Array["npoints"], Array["npoints"], Array["npoints"], Array["npoints"]]:
        Tuple of new pressure, mass, radius and lambdas arrays, sorted by increasing
        mass. Replaced entries hold the values at the TOV point (pc, M, R and Lambda
        at max(m)).
    """
    # Fetch the MTOV; its values are duplicated wherever the NS family is discarded
    m_at_TOV = jnp.max(m)
    idx_TOV = jnp.argmax(m)
    pc_at_TOV = pc[idx_TOV]
    r_at_TOV = r[idx_TOV]
    l_at_TOV = l[idx_TOV]

    # jnp.diff has length npoints - 1: entry i tells whether m[i + 1] > m[i].
    # Inserting True at idx_TOV restores length npoints and always keeps the TOV point.
    m_is_increasing = jnp.diff(m) > 0
    m_is_increasing = jnp.insert(m_is_increasing, idx_TOV, True)
    # All indices after the MTOV index are discarded
    keep = jnp.where(jnp.arange(len(m)) > idx_TOV, False, m_is_increasing)

    pc_new = jnp.where(keep, pc, pc_at_TOV)
    m_new = jnp.where(keep, m, m_at_TOV)
    r_new = jnp.where(keep, r, r_at_TOV)
    l_new = jnp.where(keep, l, l_at_TOV)

    # Sort in increasing values of M for plotting etc.; replaced entries equal MTOV
    # and therefore end up at the tail.
    sort_idx = jnp.argsort(m_new)
    return pc_new[sort_idx], m_new[sort_idx], r_new[sort_idx], l_new[sort_idx]
###################
### SPLINES etc ###
###################
[docs]
def cubic_spline(xq: Float[Array, "m"], xp: Float[Array, "n"], fp: Float[Array, "n"]):
    """
    Evaluate a cubic spline through the data points (xp, fp) at the query points xq,
    using interpax (https://github.com/f0uriest/interpax).

    Args:
        xq (Float[Array, "m"]): x values at which the spline interpolant is evaluated;
            their number is independent of the number of data points.
        xp (Float[Array, "n"]): x values of the data points
        fp (Float[Array, "n"]): y values of the data points, i.e. fp = f(xp)

    Returns:
        The spline values at xq, as returned by interpax.interp1d(..., method="cubic").
        Query points outside [xp[0], xp[-1]] follow interpax's default `extrap` setting.
    """
    return interpax_interp1d(xq, xp, fp, method="cubic")
[docs]
def sigmoid(x: Array) -> Array:
    """
    Logistic sigmoid 1 / (1 + exp(-x)), evaluated elementwise.

    The naive form overflows exp(-x) to inf for large negative x. The value stays
    0 in that case, but the reverse-mode gradient becomes inf/inf = NaN.
    Here each branch only ever exponentiates a non-positive number:
    1 / (1 + exp(-x)) for x >= 0 and exp(x) / (1 + exp(x)) for x < 0.
    """
    is_nonneg = x >= 0
    # Feed each branch only the inputs it handles; the dummy 0 in the unused slots
    # keeps the unselected branch (and hence its gradient contribution) finite.
    x_nonneg = jnp.where(is_nonneg, x, 0.0)
    x_neg = jnp.where(is_nonneg, 0.0, x)
    exp_neg = jnp.exp(x_neg)
    return jnp.where(
        is_nonneg,
        1.0 / (1.0 + jnp.exp(-x_nonneg)),
        exp_neg / (1.0 + exp_neg),
    )
[docs]
def calculate_rest_mass_density(e: Float[Array, "n"], p: Float[Array, "n"]):
    """
    Compute rest-mass density given arrays of energy density and pressure.

    Solves d(rho)/d(e) = rho / (p(e) + e) with diffrax (Tsit5, PID step-size control)
    from e[0] to e[-1]. Here p(e) is the linear interpolation of the (e, p) table,
    and the solution is sampled at every entry of e.

    Parameters:
    - e (jax.numpy array): Array of energy density values. Must be increasing: it is
      both the x-grid of jnp.interp and the save grid of the ODE solve.
    - p (jax.numpy array): Array of pressure values at the same points as e, in the
      same units as e (the two are added in the ODE right-hand side).
    Returns:
    - rho (jax.numpy array): Array of rest-mass density values at each entry of e.
      Because rho is seeded with e[0], it carries the units of e.
    """
    # Define a linear interpolation for p(e)
    def p_interp(e_val):
        return jnp.interp(e_val, e, p)

    # Define the ODE: drho/de = rho / (p + e); the energy density plays the role of "time"
    def rhs(t, rho, args):
        p_val = p_interp(t)
        return rho / (p_val + t)

    # Initial condition: rho[0] = e[0], i.e. the rest-mass density is taken equal to the
    # energy density at the first table point (presumably because the internal energy is
    # negligible there -- confirm for the EOS table in use)
    rho0 = e[0]
    # Define the term for the ODE
    term = ODETerm(rhs)
    # Solve the ODE using diffrax
    solver = Tsit5()
    solution = diffeqsolve(
        term,
        solver,
        t0=e[0],  # Initial value of e
        t1=e[-1],  # Final value of e
        dt0=1e-8,  # Initial step size, in the units of e
        y0=rho0,  # Initial value of rho
        saveat=SaveAt(ts=e),
        stepsize_controller=PIDController(rtol=1e-5, atol=1e-6),
    )
    # Return the rho values at specified e points
    return solution.ys