# Source code for jesterTOV.eos

import os
import jax
import jax.numpy as jnp
from jax.scipy.special import factorial
from jaxtyping import Array, Float, Int

from . import utils, tov, ptov

##############
### CRUSTS ###
##############

# Directory containing this module. It is resolved to an absolute path so that the
# derived crust directory never collapses to "/crust" when __file__ has no directory part.
DEFAULT_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory searched by load_crust for named crust tables (``<name>.npz`` files).
CRUST_DIR = os.path.join(DEFAULT_DIR, "crust")


def load_crust(name: str) -> tuple[Array, Array, Array]:
    """
    Load a crust file from the default directory.

    Args:
        name (str): Name of a crust stored in ``CRUST_DIR`` (the file stem, e.g. ``"DH"``),
            or a path to a ``.npz`` file outside of jester.

    Returns:
        tuple[Array, Array, Array]: Number densities [fm^-3], pressures [MeV / fm^3],
        and energy densities [MeV / fm^3] of the crust, read from the ``"n"``, ``"p"``
        and ``"e"`` entries of the ``.npz`` file.

    Raises:
        ValueError: If ``name`` does not end in ``.npz`` and no crust with that name
            exists in ``CRUST_DIR``.
    """
    # Bare names are resolved against the crust directory; explicit ``.npz`` paths are
    # used as-is (so the crust directory is only listed when it is actually needed).
    if not name.endswith(".npz"):
        # splitext keeps dots inside the stem intact (e.g. "a.b.npz" -> "a.b")
        available_crust_names = [
            os.path.splitext(f)[0] for f in os.listdir(CRUST_DIR) if f.endswith(".npz")
        ]
        if name not in available_crust_names:
            raise ValueError(
                f"Crust {name} not found in {CRUST_DIR}. Available crusts are {available_crust_names}"
            )
        name = os.path.join(CRUST_DIR, f"{name}.npz")

    # For .npz archives jnp.load hands back numpy's NpzFile, which keeps the file open;
    # read the arrays inside a ``with`` block so the handle is released afterwards.
    with jnp.load(name) as crust:
        n, p, e = crust["n"], crust["p"], crust["e"]
    return n, p, e
class Interpolate_EOS_model:
    """
    Base class for EOS models.

    Provides ``interpolate_eos``, which converts tabulated (n, p, e) data to geometric
    units and derives the auxiliary arrays returned by the ``construct_eos`` methods.
    """

    def __init__(self):
        pass

    def interpolate_eos(
        self,
        n: Float[Array, "n_points"],
        p: Float[Array, "n_points"],
        e: Float[Array, "n_points"],
    ):
        """
        Convert n, p and e to geometric units and compute auxiliary quantities.

        Args:
            n (Float[Array, n_points]): Number densities. Expected units are n[fm^-3]
            p (Float[Array, n_points]): Pressure values. Expected units are p[MeV / fm^3]
            e (Float[Array, n_points]): Energy densities. Expected units are e[MeV / fm^3]

        Returns:
            tuple: ``(ns, ps, hs, es, dloge_dlogps)``. These are the number density,
            pressure, enthalpy and energy density in geometric units, followed by the
            logarithmic derivative dlog(e)/dlog(p).
        """
        fm_to_geo = utils.fm_inv3_to_geometric
        MeV_fm_to_geo = utils.MeV_fm_inv3_to_geometric
        ns = jnp.array(n * fm_to_geo)
        ps = jnp.array(p * MeV_fm_to_geo)
        es = jnp.array(e * MeV_fm_to_geo)

        # Enthalpy h = int dp / (e + p), written as an integral of p / (e + p) over log(p)
        hs = utils.cumtrapz(ps / (es + ps), jnp.log(ps))

        # Finite-difference slope of log(e) vs log(p). It is unit independent, so the raw
        # inputs are used. The first slope is repeated so the result matches the input length.
        slopes = jnp.diff(jnp.log(e)) / jnp.diff(jnp.log(p))
        dloge_dlogps = jnp.concatenate((slopes[:1], slopes))

        return ns, ps, hs, es, dloge_dlogps
class MetaModel_EOS_model(Interpolate_EOS_model):
    """
    MetaModel_EOS_model is a class to interpolate EOS data with a meta-model.

    The EOS is assembled on three density regions:

    * the tabulated crust, up to ``max_n_crust_nsat * nsat``;
    * a connection region, where the speed of sound is obtained by cubic-spline
      interpolation between the crust and the metamodel;
    * the nucleonic metamodel, from ``nmin_MM_nsat * nsat`` to ``nmax_nsat * nsat``.

    Pressure and energy density are then rebuilt from the speed of sound by integrating
    the chemical potential.
    """

    def __init__(
        self,
        kappas: tuple[Float, Float, Float, Float, Float, Float] = (
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        ),
        v_nq: list[float] | tuple[float, ...] = (0.0, 0.0, 0.0, 0.0, 0.0),
        b_sat: Float = 17.0,
        b_sym: Float = 25.0,
        # density parameters
        nsat: Float = 0.16,
        nmin_MM_nsat: Float = 0.12 / 0.16,
        nmax_nsat: Float = 12,
        ndat: Int = 200,
        # crust parameters
        crust_name: str = "DH",
        max_n_crust_nsat: Float = 0.5,
        ndat_spline: Int = 10,
        # proton fraction
        proton_fraction: bool | float | None = None,
    ):
        """
        Initialize the MetaModel_EOS_model with the provided coefficients and compute auxiliary data.

        Main reference for coefficients: PHYSICAL REVIEW C 103, 045803 (2021)

        Args:
            kappas (tuple[Float, Float, Float, Float, Float, Float], optional): The coefficients
                (kappa_sat, kappa_sat2, kappa_sat3, kappa_NM, kappa_NM2, kappa_NM3) of the
                kinetic part of the metamodel. Defaults to all zeros.
            v_nq (list[float] | tuple[float, ...], optional): The five coefficients of the
                delta^4 (non-quadratic) term of the potential energy. Defaults to all zeros.
            b_sat (Float, optional): The saturation coefficient of the low-density
                correction. Defaults to 17.0.
            b_sym (Float, optional): The symmetry coefficient of the low-density
                correction. Defaults to 25.0.
            nsat (Float, optional): Saturation density. Defaults to 0.16 fm^-3.
            nmin_MM_nsat (Float, optional): Starting point of densities in units of nsat for
                the metamodel part of the EOS. Defaults to 0.12 / 0.16.
            nmax_nsat (Float, optional): End point (exclusive) of densities in units of nsat
                for the metamodel part of the EOS. Defaults to 12.
            ndat (Int, optional): Number of datapoints to be used for the metamodel part of
                the EOS. Defaults to 200.
            crust_name (str, optional): Name of the crust file to load from the crust
                directory, or a filename if a file outside of jester is supplied.
                Defaults to "DH".
            max_n_crust_nsat (Float, optional): Maximum number density in units of nsat for
                crust data loading and interpolation. Defaults to 0.5.
            ndat_spline (Int, optional): Number of points for cubic spline interpolation in
                the connection region between crust and metamodel parts of the EOS.
                Defaults to 10.
            proton_fraction (bool | float | None, optional): If a float, the proton fraction
                is fixed to that value at all densities. Otherwise it is computed from beta
                equilibrium with ``compute_proton_fraction``. Defaults to None.
        """
        # Save as attributes
        self.nsat = nsat
        self.v_nq = jnp.array(v_nq)
        self.b_sat = b_sat
        self.b_sym = b_sym
        self.N = 4  # TODO: this is fixed in the metamodeling papers, but we might want to extend this in the future
        self.nmin_MM_nsat = nmin_MM_nsat
        self.nmax_nsat = nmax_nsat
        self.ndat = ndat
        self.max_n_crust_nsat = max_n_crust_nsat
        self.ndat_spline = ndat_spline

        if isinstance(proton_fraction, float):
            self.proton_fraction_val = proton_fraction
            self.proton_fraction = lambda x, y: self.proton_fraction_val
            print(f"Proton fraction fixed to {self.proton_fraction_val}")
        else:
            self.proton_fraction = lambda x, y: self.compute_proton_fraction(x, y)

        # Constructions
        assert (
            len(kappas) == 6
        ), "kappas must be a tuple of 6 values: kappa_sat, kappa_sat2, kappa_sat3, kappa_NM, kappa_NM2, kappa_NM3"
        (
            self.kappa_sat,
            self.kappa_sat2,
            self.kappa_sat3,
            self.kappa_NM,
            self.kappa_NM2,
            self.kappa_NM3,
        ) = kappas
        self.kappa_sym = self.kappa_NM - self.kappa_sat
        self.kappa_sym2 = self.kappa_NM2 - self.kappa_sat2
        self.kappa_sym3 = self.kappa_NM3 - self.kappa_sat3

        # t_sat or TFGsat is the kinetic energy per nucleons in SM and at saturation, see just after eq (13) in the Margueron paper
        self.t_sat = (
            3
            * utils.hbar**2
            / (10 * utils.m)
            * (3 * jnp.pi**2 * self.nsat / 2) ** (2 / 3)
        )

        # Kinetic (NEP-independent) parts of the potential coefficients.
        # v_sat is defined in equations (22) - (26) in the Margueron et al. paper, and v_sym2
        # in equations (27) - (31). The 2^(2/3) factor is the ratio of the Fermi-gas kinetic
        # energy in neutron matter to that in symmetric matter.
        sat_terms = self._kinetic_v_terms(
            self.kappa_sat, self.kappa_sat2, self.kappa_sat3
        )
        nm_terms = self._kinetic_v_terms(self.kappa_NM, self.kappa_NM2, self.kappa_NM3)
        (
            self.v_sat_0_no_NEP,
            self.v_sat_1_no_NEP,
            self.v_sat_2_no_NEP,
            self.v_sat_3_no_NEP,
            self.v_sat_4_no_NEP,
        ) = [-self.t_sat * term_sat for term_sat in sat_terms]
        (
            self.v_sym2_0_no_NEP,
            self.v_sym2_1_no_NEP,
            self.v_sym2_2_no_NEP,
            self.v_sym2_3_no_NEP,
            self.v_sym2_4_no_NEP,
        ) = [
            -self.t_sat * (2 ** (2 / 3) * term_nm - term_sat) - self.v_nq[alpha]
            for alpha, (term_sat, term_nm) in enumerate(zip(sat_terms, nm_terms))
        ]

        # Load and preprocess the crust
        ns_crust, ps_crust, es_crust = load_crust(crust_name)
        max_n_crust = max_n_crust_nsat * nsat
        mask = ns_crust <= max_n_crust
        self.ns_crust, self.ps_crust, self.es_crust = (
            ns_crust[mask],
            ps_crust[mask],
            es_crust[mask],
        )
        # Chemical potential (e + p) / n at the lowest crust density: integration constant for mu
        self.mu_lowest = (es_crust[0] + ps_crust[0]) / ns_crust[0]
        self.cs2_crust = jnp.gradient(self.ps_crust, self.es_crust)

        # Make sure the metamodel starts above the crust
        self.max_n_crust = self.ns_crust[-1]

        # Create density arrays
        self.nmax = nmax_nsat * self.nsat
        self.nmin_MM = self.nmin_MM_nsat * self.nsat
        self.n_metamodel = jnp.linspace(
            self.nmin_MM, self.nmax, self.ndat, endpoint=False
        )
        # Nodes (crust + metamodel) of the cs2 spline used in the connection region
        self.ns_spline = jnp.append(self.ns_crust, self.n_metamodel)
        self.n_connection = jnp.linspace(
            self.max_n_crust + 1e-5, self.nmin_MM, self.ndat_spline, endpoint=False
        )

    @staticmethod
    def _kinetic_v_terms(k1, k2, k3) -> tuple:
        """Return the kappa polynomials multiplying -t_sat in v_alpha, for alpha = 0..4."""
        return (
            1 + k1 + k2 + k3,
            2 + 5 * k1 + 8 * k2 + 11 * k3,
            2 * (-1 + 5 * k1 + 20 * k2 + 44 * k3),
            2 * (4 - 5 * k1 + 40 * k2 + 220 * k3),
            8 * (-7 + 5 * k1 - 10 * k2 + 110 * k3),
        )

    def construct_eos(self, NEP_dict: dict) -> tuple:
        """
        Construct the EOS.

        Args:
            NEP_dict (dict): Nuclear empirical parameters. Recognized keys are E_sat
                (default -16.0), K_sat, Q_sat, Z_sat, E_sym, L_sym, K_sym, Q_sym and Z_sym
                (all defaulting to 0.0).

        Returns:
            tuple: ``(ns, ps, hs, es, dloge_dlogps)`` as returned by ``interpolate_eos``,
            followed by the chemical potential ``mu`` and the speed of sound squared ``cs2``
            on the combined crust + connection + metamodel density grid.
        """
        E_sat = NEP_dict.get(
            "E_sat", -16.0
        )  # NOTE: this is a common default value, therefore not zero!
        K_sat = NEP_dict.get("K_sat", 0.0)
        Q_sat = NEP_dict.get("Q_sat", 0.0)
        Z_sat = NEP_dict.get("Z_sat", 0.0)

        E_sym = NEP_dict.get("E_sym", 0.0)
        L_sym = NEP_dict.get("L_sym", 0.0)
        K_sym = NEP_dict.get("K_sym", 0.0)
        Q_sym = NEP_dict.get("Q_sym", 0.0)
        Z_sym = NEP_dict.get("Z_sym", 0.0)

        # Taylor coefficients in x, divided by k!. The 0.0 is the first-derivative slot of
        # the saturation expansion (zero at saturation), so both arrays have five entries.
        coefficient_sat = jnp.array([E_sat, 0.0, K_sat, Q_sat, Z_sat])
        coefficient_sym = jnp.array([E_sym, L_sym, K_sym, Q_sym, Z_sym])
        coefficient_sat = coefficient_sat / factorial(jnp.arange(len(coefficient_sat)))
        coefficient_sym = coefficient_sym / factorial(jnp.arange(len(coefficient_sym)))

        # Potential energy (v_sat is defined in equations (22) - (26) in the Margueron et al. paper)
        v_sat = jnp.array(
            [
                E_sat + self.v_sat_0_no_NEP,
                0.0 + self.v_sat_1_no_NEP,
                K_sat + self.v_sat_2_no_NEP,
                Q_sat + self.v_sat_3_no_NEP,
                Z_sat + self.v_sat_4_no_NEP,
            ]
        )
        # v_sym2 is defined in equations (27) to (31) in the Margueron et al. paper
        v_sym2 = jnp.array(
            [
                E_sym + self.v_sym2_0_no_NEP,
                L_sym + self.v_sym2_1_no_NEP,
                K_sym + self.v_sym2_2_no_NEP,
                Q_sym + self.v_sym2_3_no_NEP,
                Z_sym + self.v_sym2_4_no_NEP,
            ]
        )

        # Auxiliaries first
        x = self.compute_x(self.n_metamodel)
        proton_fraction = self.proton_fraction(coefficient_sym, self.n_metamodel)
        delta = 1 - 2 * proton_fraction
        f_1 = self.compute_f_1(delta)
        f_star = self.compute_f_star(delta)
        f_star2 = self.compute_f_star2(delta)
        f_star3 = self.compute_f_star3(delta)
        v = self.compute_v(v_sat, v_sym2, delta)
        b = self.compute_b(delta)

        # Other quantities
        p_metamodel = self.compute_pressure(x, f_1, f_star, f_star2, f_star3, b, v)
        e_metamodel = self.compute_energy(x, f_1, f_star, f_star2, f_star3, b, v)

        # Get cs2 for the metamodel
        cs2_metamodel = self.compute_cs2(
            self.n_metamodel,
            p_metamodel,
            e_metamodel,
            x,
            delta,
            f_1,
            f_star,
            f_star2,
            f_star3,
            b,
            v,
        )

        # Spline for speed of sound for the connection region
        cs2_spline = jnp.append(self.cs2_crust, cs2_metamodel)
        cs2_connection = utils.cubic_spline(
            self.n_connection, self.ns_spline, cs2_spline
        )
        cs2_connection = jnp.clip(cs2_connection, 1e-5, 1.0)

        # Concatenate the arrays
        n = jnp.concatenate([self.ns_crust, self.n_connection, self.n_metamodel])
        cs2 = jnp.concatenate([self.cs2_crust, cs2_connection, cs2_metamodel])

        # Make sure the cs2 stays within the physical limits
        cs2 = jnp.clip(cs2, 1e-5, 1.0)

        # dlog(mu)/dlog(n) = cs2, dp/dn = cs2 * mu and e = mu * n - p
        log_mu = utils.cumtrapz(cs2, jnp.log(n)) + jnp.log(self.mu_lowest)
        mu = jnp.exp(log_mu)
        p = utils.cumtrapz(cs2 * mu, n) + self.ps_crust[0]
        e = mu * n - p

        ns, ps, hs, es, dloge_dlogps = self.interpolate_eos(n, p, e)

        return ns, ps, hs, es, dloge_dlogps, mu, cs2

    #################
    ### AUXILIARY ###
    #################

    def u(self, x: Array, b: Array, alpha: Int):
        """Low-density correction u_alpha = 1 - (-3x)^(N+1-alpha) exp(-b(1+3x)); it is 0 at n = 0 (x = -1/3)."""
        return 1 - ((-3 * x) ** (self.N + 1 - alpha) * jnp.exp(-b * (1 + 3 * x)))

    def compute_x(self, n: Array):
        """Expansion variable x = (n - nsat) / (3 nsat)."""
        return (n - self.nsat) / (3 * self.nsat)

    def compute_b(self, delta: Array | float):
        """Low-density correction parameter b = b_sat + b_sym * delta^2."""
        return self.b_sat + self.b_sym * delta**2

    def compute_f_1(self, delta: Array | float):
        """Isospin factor (1 + delta)^(5/3) + (1 - delta)^(5/3) of the Fermi-gas term."""
        return (1 + delta) ** (5 / 3) + (1 - delta) ** (5 / 3)

    @staticmethod
    def _kappa_weighted_f(kappa_s, kappa_v, delta):
        """Return (kappa_s + kappa_v delta)(1 + delta)^(5/3) + (kappa_s - kappa_v delta)(1 - delta)^(5/3)."""
        return (kappa_s + kappa_v * delta) * (1 + delta) ** (5 / 3) + (
            kappa_s - kappa_v * delta
        ) * (1 - delta) ** (5 / 3)

    def compute_f_star(self, delta: Array | float):
        """Isospin factor of the linear kappa term."""
        return self._kappa_weighted_f(self.kappa_sat, self.kappa_sym, delta)

    def compute_f_star2(self, delta: Array | float):
        """Isospin factor of the quadratic kappa term."""
        return self._kappa_weighted_f(self.kappa_sat2, self.kappa_sym2, delta)

    def compute_f_star3(self, delta: Array | float):
        """Isospin factor of the cubic kappa term."""
        return self._kappa_weighted_f(self.kappa_sat3, self.kappa_sym3, delta)

    def compute_v(self, v_sat: Array, v_sym2: Array, delta: Array | float) -> Array:
        """Potential coefficients v_alpha = v_sat + v_sym2 delta^2 + v_nq delta^4, alpha = 0..N."""
        return jnp.array(
            [
                v_sat[alpha] + v_sym2[alpha] * delta**2 + self.v_nq[alpha] * delta**4
                for alpha in range(self.N + 1)
            ]
        )

    def compute_energy(
        self,
        x: Array,
        f_1: Array,
        f_star: Array,
        f_star2: Array,
        f_star3: Array,
        b: Array,
        v: Array,
    ) -> Array:
        """Energy per nucleon: kinetic part plus sum_alpha v_alpha / alpha! x^alpha u_alpha."""
        prefac = self.t_sat / 2 * (1 + 3 * x) ** (2 / 3)
        linear = (1 + 3 * x) * f_star
        quadratic = (1 + 3 * x) ** 2 * f_star2
        cubic = (1 + 3 * x) ** 3 * f_star3

        kinetic_energy = prefac * (f_1 + linear + quadratic + cubic)

        # Potential energy
        potential_energy = 0
        for alpha in range(self.N + 1):
            u = self.u(x, b, alpha)
            potential_energy += v.at[alpha].get() / (factorial(alpha)) * x**alpha * u

        return kinetic_energy + potential_energy

    def esym(self, coefficient_sym: list, x: Array):
        """
        Symmetry energy from its Taylor expansion sum_k coefficient_sym[k] * x^k.

        ``coefficient_sym`` must already be divided by k! (as done in ``construct_eos``), and
        ``x`` is the expansion variable from ``compute_x`` (not the number density).
        """
        # TODO: change this to be self-consistent: see Rahul's approach for that.
        return jnp.polyval(jnp.array(coefficient_sym[::-1]), x)

    def compute_pressure(
        self,
        x: Array,
        f_1: Array,
        f_star: Array,
        f_star2: Array,
        f_star3: Array,
        b: Array,
        v: Array,
    ) -> Array:
        """Pressure P = n^2 dE/dn = (1/3) nsat (1 + 3x)^2 dE/dx of the metamodel energy."""
        # TODO: currently only for ELFc!
        p_kin = (
            1
            / 3
            * self.nsat
            * self.t_sat
            * (1 + 3 * x) ** (5 / 3)
            * (
                f_1
                + 5 / 2 * (1 + 3 * x) * f_star
                + 4 * (1 + 3 * x) ** 2 * f_star2
                + 11 / 2 * (1 + 3 * x) ** 3 * f_star3
            )
        )

        # d/dx of v_alpha / alpha! x^alpha u_alpha for alpha >= 1
        p_pot = 0
        for alpha in range(1, self.N + 1):
            u = self.u(x, b, alpha)
            fac1 = alpha * u
            fac2 = (self.N + 1 - alpha - 3 * b * x) * (u - 1)
            p_pot += (
                v.at[alpha].get()
                / (factorial(alpha))
                * x ** (alpha - 1)
                * (fac1 + fac2)
            )

        # alpha = 0 term: v_0 du_0/dx, written without the x^(alpha - 1) factor
        p_pot = p_pot - v.at[0].get() * (-3) ** (self.N + 1) * x**self.N * (
            self.N + 1 - 3 * b * x
        ) * jnp.exp(-b * (1 + 3 * x))
        p_pot = p_pot * (1 / 3) * self.nsat * (1 + 3 * x) ** 2

        return p_pot + p_kin

    def compute_cs2(
        self,
        n: Array,
        p: Array,
        e: Array,
        x: Array,
        delta: Array | float,
        f_1: Array,
        f_star: Array,
        f_star2: Array,
        f_star3: Array,
        b: Array,
        v: Array,
    ):
        """
        Speed of sound squared cs2 = (dP/dn) / ((eps + P) / n), including electrons.

        ``e`` is the nucleonic energy per nucleon and ``p`` the nucleonic pressure. The
        incompressibility K = 9 dP/dn is built from its kinetic, potential and electron
        contributions.
        """
        ### Compute incompressibility

        # Kinetic part
        K_kin = (
            self.t_sat
            * (1 + 3 * x) ** (2 / 3)
            * (
                -f_1
                + 5 * (1 + 3 * x) * f_star
                + 20 * (1 + 3 * x) ** 2 * f_star2
                + 44 * (1 + 3 * x) ** 3 * f_star3
            )
        )

        # Potential part: x^(alpha-2) [alpha(alpha-1) u + 2 alpha x u' + x^2 u''] for alpha >= 2
        K_pot = 0
        for alpha in range(2, self.N + 1):
            u = self.u(x, b, alpha)
            x_up = (self.N + 1 - alpha - 3 * b * x) * (u - 1)
            x2_upp = (
                -(self.N + 1 - alpha) * (self.N - alpha)
                + 6 * b * x * (self.N + 1 - alpha)
                - 9 * x**2 * b**2
            ) * (1 - u)

            K_pot = K_pot + v.at[alpha].get() / (factorial(alpha)) * x ** (
                alpha - 2
            ) * (alpha * (alpha - 1) * u + 2 * alpha * x_up + x2_upp)

        # alpha = 0 and alpha = 1 handled separately to avoid negative powers of x
        K_pot += (
            v.at[0].get()
            * (-(self.N + 1) * (self.N) + 6 * b * x * (self.N + 1) - 9 * x**2 * b**2)
            * ((-3) ** (self.N + 1) * x ** (self.N - 1) * jnp.exp(-b * (1 + 3 * x)))
        )
        K_pot += (
            2
            * v.at[1].get()
            * (self.N - 3 * b * x)
            * (-((-3) ** (self.N)) * x ** (self.N - 1) * jnp.exp(-b * (1 + 3 * x)))
        )
        K_pot += (
            v.at[1].get()
            * (-(self.N) * (self.N - 1) + 6 * b * x * (self.N) - 9 * x**2 * b**2)
            * ((-3) ** (self.N) * x ** (self.N - 1) * jnp.exp(-b * (1 + 3 * x)))
        )
        K_pot *= (1 + 3 * x) ** 2

        K = K_kin + K_pot + 18 / n * p

        # Electrons: relativistic Fermi gas with electron density equal to the proton
        # density (1 - delta) / 2 * n
        K_Fb = (3.0 * jnp.pi**2 / 2.0 * n) ** (1.0 / 3.0) * utils.hbarc
        K_Fe = K_Fb * (1.0 - delta) ** (1.0 / 3.0)
        C = utils.m_e**4 / (8.0 * jnp.pi**2) / utils.hbarc**3
        x_e = K_Fe / utils.m_e  # dimensionless electron Fermi momentum
        f = x_e * (1 + 2 * x_e**2) * jnp.sqrt(1 + x_e**2) - jnp.arcsinh(x_e)
        e_electron = C * f
        p_electron = -e_electron + 8.0 / 3.0 * C * x_e**3 * jnp.sqrt(1 + x_e**2)
        K_electron = 8 * C / n * x_e**3 * (3 + 4 * x_e**2) / (
            jnp.sqrt(1 + x_e**2)
        ) - 9 / n * (e_electron + p_electron)

        # Sum together:
        K_tot = K + K_electron

        # Finally, get cs2 = (dP/dn) / h with h = (eps + P) / n
        chi = K_tot / 9.0

        total_energy_density = (e + utils.m) * n + e_electron
        total_pressure = p + p_electron
        h_tot = (total_energy_density + total_pressure) / n

        cs2 = chi / h_tot

        return cs2

    def compute_proton_fraction(
        self, coefficient_sym: list, n: Array
    ) -> Float[Array, "n_points"]:
        """
        Computes the proton fraction for a given number density from beta equilibrium.

        With y = x_p^(1/3), the electron chemical potential is
        mu_e = hbarc (3 pi^2 n)^(1/3) y, and mu_n - mu_p = 4 Esym (1 - 2 x_p) plus
        the neutron-proton mass difference. Beta equilibrium then gives the cubic

            8 Esym y^3 + hbarc (3 pi^2 n)^(1/3) y - 4 Esym - (m_n - m_p) = 0,

        whose real root in [0, 1] is cubed to obtain x_p.

        Args:
            coefficient_sym (list): Symmetry-energy Taylor coefficients, already divided by
                k! (see ``esym``).
            n (Float[Array, "n_points"]): Number density in fm^-3.

        Returns:
            Float[Array, "n_points"]: Proton fraction as a function of the number density.
        """
        # The symmetry-energy expansion is in x = (n - nsat) / (3 nsat), not in n itself.
        x = self.compute_x(n)
        Esym = self.esym(coefficient_sym, x)

        a = 8.0 * Esym
        b = jnp.zeros(shape=n.shape)
        c = utils.hbarc * jnp.power(3.0 * jnp.pi**2 * n, 1.0 / 3.0)
        d = -4.0 * Esym - (utils.m_n - utils.m_p)

        coeffs = jnp.array([a, b, c, d]).T
        ys = utils.cubic_root_for_proton_fraction(coeffs)

        # Keep only real roots in [0, 1]; roots failing the test contribute 0 to the sum
        physical_ys = jnp.where(
            (ys.imag == 0.0) * (ys.real >= 0.0) * (ys.real <= 1.0),
            ys.real,
            jnp.zeros_like(ys.real),
        ).sum(axis=1)
        proton_fraction = jnp.power(physical_ys, 3)
        return proton_fraction
class MetaModel_with_CSE_EOS_model(Interpolate_EOS_model):
    """
    MetaModel_with_CSE_EOS_model is a class to interpolate EOS data with a meta-model
    below a break density and a speed-of-sound extension (CSE) above it.

    In the extension, cs2 is piecewise-linear in the number density through the
    user-supplied grid points.
    """

    def __init__(
        self,
        nsat: Float = 0.16,
        nmin_MM_nsat: Float = 0.12 / 0.16,
        nmax_nsat: Float = 12,
        ndat_metamodel: Int = 100,
        ndat_CSE: Int = 100,
        **metamodel_kwargs,
    ):
        """
        Store the grid settings; the EOS itself is built in ``construct_eos``.

        Args:
            nsat (Float, optional): Saturation density. Defaults to 0.16 fm^-3.
            nmin_MM_nsat (Float, optional): Starting density of the metamodel part in units
                of nsat. Defaults to 0.12 / 0.16.
            nmax_nsat (Float, optional): End point of the EOS in units of nsat. Defaults to 12.
            ndat_metamodel (Int, optional): Number of datapoints for the metamodel part of
                the EOS. Defaults to 100.
            ndat_CSE (Int, optional): Number of datapoints for the CSE part of the EOS.
                Defaults to 100.
            **metamodel_kwargs: Extra keyword arguments forwarded to MetaModel_EOS_model.
        """
        self.nsat = nsat
        self.nmin_MM_nsat = nmin_MM_nsat
        self.nmax = nmax_nsat * nsat
        self.ndat_metamodel = ndat_metamodel
        self.ndat_CSE = ndat_CSE
        self.metamodel_kwargs = metamodel_kwargs

    def construct_eos(
        self,
        NEP_dict: dict,
        ngrids: Float[Array, "n_grid_point"],
        cs2grids: Float[Array, "n_grid_point"],
    ) -> tuple:
        """
        Construct the EOS.

        Args:
            NEP_dict (dict): NEP keys passed to the metamodel EOS class, plus ``"nbreak"``,
                the transition density [fm^-3] between metamodel and CSE.
            ngrids (Float[Array, `n_grid_point`]): Density grid points [fm^-3] of the CSE part.
            cs2grids (Float[Array, `n_grid_point`]): Speed-of-sound squared at ``ngrids``.

        Returns:
            tuple: EOS quantities (see Interpolate_EOS_model), as well as the chemical
            potential and speed of sound.
        """
        nbreak = NEP_dict["nbreak"]

        # Metamodel part, up to the break density
        metamodel = MetaModel_EOS_model(
            nsat=self.nsat,
            nmin_MM_nsat=self.nmin_MM_nsat,
            nmax_nsat=nbreak / self.nsat,
            ndat=self.ndat_metamodel,
            **self.metamodel_kwargs,
        )
        n_mm, p_mm, _, e_mm, _, mu_mm, cs2_mm = metamodel.construct_eos(NEP_dict)

        # Back to fm^-3 and MeV / fm^3 for the CSE construction
        n_mm = n_mm / utils.fm_inv3_to_geometric
        p_mm = p_mm / utils.MeV_fm_inv3_to_geometric
        e_mm = e_mm / utils.MeV_fm_inv3_to_geometric

        # Matching values at the break density
        p_break, e_break, mu_break, cs2_break = (
            jnp.interp(nbreak, n_mm, quantity)
            for quantity in (p_mm, e_mm, mu_mm, cs2_mm)
        )

        # cs2 nodes of the extension: the break point followed by the user grid
        n_nodes = jnp.concatenate((jnp.array([nbreak]), ngrids))
        cs2_nodes = jnp.concatenate((jnp.array([cs2_break]), cs2grids))

        # CSE densities (fm^-3), log-spaced from nbreak to nmax
        n_CSE = jnp.logspace(jnp.log10(nbreak), jnp.log10(self.nmax), num=self.ndat_CSE)
        cs2_CSE = jnp.interp(n_CSE, n_nodes, cs2_nodes)

        # dlog(mu)/dn = cs2 / n, dp/dn = cs2 * mu, de/dn = mu.
        # A very small number is added to avoid problems with duplicates after concatenation.
        mu_CSE = mu_break * jnp.exp(utils.cumtrapz(cs2_CSE / n_CSE, n_CSE)) + 1e-6
        p_CSE = p_break + utils.cumtrapz(cs2_CSE * mu_CSE, n_CSE) + 1e-6
        e_CSE = e_break + utils.cumtrapz(mu_CSE, n_CSE) + 1e-6

        # Combine metamodel and CSE data
        n = jnp.concatenate((n_mm, n_CSE))
        p = jnp.concatenate((p_mm, p_CSE))
        e = jnp.concatenate((e_mm, e_CSE))
        mu = jnp.concatenate((mu_mm, mu_CSE))
        cs2 = jnp.concatenate((cs2_mm, cs2_CSE))

        ns, ps, hs, es, dloge_dlogps = self.interpolate_eos(n, p, e)
        return ns, ps, hs, es, dloge_dlogps, mu, cs2
class MetaModel_with_peakCSE_EOS_model(Interpolate_EOS_model):
    """
    MetaModel_with_peakCSE_EOS_model is a class to interpolate EOS data with a meta-model
    and using the CSE.

    The parametrization of the CSE is based on the peakCSE model: a Gaussian peak plus a
    logistic ("logit") rise of cs2 from its value at the break density towards 1/3. This
    guarantees consistency with pQCD at the highest densities.
    """

    def __init__(
        self,
        nsat: Float = 0.16,
        nmin_MM_nsat: Float = 0.12 / 0.16,
        nmax_nsat: Float = 12,
        ndat_metamodel: Int = 100,
        ndat_CSE: Int = 100,
        **metamodel_kwargs,
    ):
        """
        Store the grid settings; the EOS itself is built in ``construct_eos``.

        Args:
            nsat (Float, optional): Saturation density. Defaults to 0.16 fm^-3.
            nmin_MM_nsat (Float, optional): Starting density of the metamodel part in units
                of nsat. Defaults to 0.12 / 0.16.
            nmax_nsat (Float, optional): End point of the EOS in units of nsat. Defaults to 12.
            ndat_metamodel (Int, optional): Number of datapoints for the metamodel part of
                the EOS. Defaults to 100.
            ndat_CSE (Int, optional): Number of datapoints for the CSE part of the EOS.
                Defaults to 100.
            **metamodel_kwargs: Extra keyword arguments forwarded to MetaModel_EOS_model.
        """
        # TODO: align with new metamodel code
        self.nsat = nsat
        self.nmin_MM_nsat = nmin_MM_nsat
        self.nmax = nmax_nsat * nsat
        self.ndat_metamodel = ndat_metamodel
        self.ndat_CSE = ndat_CSE
        self.metamodel_kwargs = metamodel_kwargs

    def construct_eos(self, NEP_dict: dict, peakCSE_dict: dict):
        """
        Construct the EOS.

        Args:
            NEP_dict (dict): NEP keys passed to the metamodel EOS class, plus ``"nbreak"``,
                the transition density [fm^-3] between metamodel and CSE.
            peakCSE_dict (dict): peakCSE parameters with keys ``"gaussian_peak"``,
                ``"gaussian_mu"``, ``"gaussian_sigma"``, ``"logit_growth_rate"`` and
                ``"logit_midpoint"`` (densities in fm^-3).

        Returns:
            tuple: EOS quantities (see Interpolate_EOS_model), as well as the chemical
            potential and speed of sound.
        """
        nbreak = NEP_dict["nbreak"]

        # Metamodel part, up to the break density
        metamodel = MetaModel_EOS_model(
            nsat=self.nsat,
            nmin_MM_nsat=self.nmin_MM_nsat,
            nmax_nsat=nbreak / self.nsat,
            ndat=self.ndat_metamodel,
            **self.metamodel_kwargs,
        )
        n_mm, p_mm, _, e_mm, _, mu_mm, cs2_mm = metamodel.construct_eos(NEP_dict)

        # Back to fm^-3 and MeV / fm^3 for the CSE construction
        n_mm = n_mm / utils.fm_inv3_to_geometric
        p_mm = p_mm / utils.MeV_fm_inv3_to_geometric
        e_mm = e_mm / utils.MeV_fm_inv3_to_geometric

        # Matching values at the break density
        p_break, e_break, mu_break, cs2_break = (
            jnp.interp(nbreak, n_mm, quantity)
            for quantity in (p_mm, e_mm, mu_mm, cs2_mm)
        )

        # The model is taken from arXiv:1812.08188, but with the number density instead of
        # the energy density as the input variable.
        def cs2_extension_function(dens):
            gaussian = peakCSE_dict["gaussian_peak"] * jnp.exp(
                -0.5
                * (
                    (dens - peakCSE_dict["gaussian_mu"]) ** 2
                    / peakCSE_dict["gaussian_sigma"] ** 2
                )
            )
            logistic = (1.0 / 3.0 - cs2_break) / (
                1.0
                + jnp.exp(
                    -peakCSE_dict["logit_growth_rate"]
                    * (dens - peakCSE_dict["logit_midpoint"])
                )
            )
            return gaussian + cs2_break + logistic

        # CSE densities (fm^-3), log-spaced from nbreak to nmax
        n_CSE = jnp.logspace(jnp.log10(nbreak), jnp.log10(self.nmax), num=self.ndat_CSE)
        cs2_CSE = cs2_extension_function(n_CSE)

        # dlog(mu)/dn = cs2 / n, dp/dn = cs2 * mu, de/dn = mu.
        # A very small number is added to avoid problems with duplicates after concatenation.
        mu_CSE = mu_break * jnp.exp(utils.cumtrapz(cs2_CSE / n_CSE, n_CSE)) + 1e-6
        p_CSE = p_break + utils.cumtrapz(cs2_CSE * mu_CSE, n_CSE) + 1e-6
        e_CSE = e_break + utils.cumtrapz(mu_CSE, n_CSE) + 1e-6

        # Combine metamodel and CSE data.
        # NOTE: duplicates at the junction would break the TOV solver, but removing them
        # in a JIT-compatible way is hard; the 1e-6 offsets above are the current workaround.
        n = jnp.concatenate((n_mm, n_CSE))
        p = jnp.concatenate((p_mm, p_CSE))
        e = jnp.concatenate((e_mm, e_CSE))
        mu = jnp.concatenate((mu_mm, mu_CSE))
        cs2 = jnp.concatenate((cs2_mm, cs2_CSE))

        ns, ps, hs, es, dloge_dlogps = self.interpolate_eos(n, p, e)
        return ns, ps, hs, es, dloge_dlogps, mu, cs2
def locate_lowest_non_causal_point(cs2):
    """
    Return the index of the first entry with cs2 >= 1 (non-causal), or -1 if none exists.

    Args:
        cs2: 1D array of speed-of-sound-squared values.

    Returns:
        Integer JAX scalar: first non-causal index, or -1 when every entry is causal.
    """
    # TODO: we might want to move this inside utils?
    non_causal = cs2 >= 1.0
    # On a boolean array, argmax returns the index of the first True entry
    # (and 0 when there is none, which the jnp.where below replaces by -1).
    first_index = jnp.argmax(non_causal)
    return jnp.where(jnp.any(non_causal), first_index, -1)
def construct_family(eos: tuple, ndat: Int = 50, min_nsat: Float = 2) -> tuple[
    Float[Array, "ndat"],
    Float[Array, "ndat"],
    Float[Array, "ndat"],
    Float[Array, "ndat"],
]:
    """
    Solve the TOV equations and generate the M, R and Lambda curves for the given EOS.

    Args:
        eos (tuple): Tuple of the EOS data (ns, ps, hs, es, dloge_dlogps), as returned by
            ``interpolate_eos``.
        ndat (int, optional): Number of datapoints used when constructing the central
            pressure grid. Defaults to 50.
        min_nsat (int, optional): Starting density for central pressure in numbers of nsat
            (assumed to be 0.16 fm^-3). Defaults to 2.

    Returns:
        tuple[Float[Array, "ndat"], Float[Array, "ndat"], Float[Array, "ndat"], Float[Array, "ndat"]]:
        log(pcs), masses in solar masses, radii in km, and dimensionless tidal
        deformabilities.
    """
    ns, ps, hs, es, dloge_dlogps = eos
    eos_dict = {"p": ps, "h": hs, "e": es, "dloge_dlogp": dloge_dlogps}

    # Lowest central pressure: the pressure at min_nsat * 0.16 fm^-3
    pc_min = utils.interp_in_logspace(
        min_nsat * 0.16 * utils.fm_inv3_to_geometric, ns, ps
    )

    # Highest central pressure: the pressure at the first non-causal point
    # (index -1, i.e. the last point, if the EOS is causal everywhere)
    cs2 = ps / es / dloge_dlogps
    pc_max = eos_dict["p"][locate_lowest_non_causal_point(cs2)]

    pcs = jnp.logspace(jnp.log10(pc_min), jnp.log10(pc_max), num=ndat)

    ms, rs, ks = jax.vmap(lambda pc: tov.tov_solver(eos_dict, pc))(pcs)

    compactness = ms / rs
    ms = ms / utils.solar_mass_in_meter  # to solar masses
    rs = rs / 1e3  # to km

    # Lambda = 2/3 * k * C^-5 (ks presumably being the Love number k2 -- confirm in tov)
    lambdas = 2.0 / 3.0 * ks * jnp.power(compactness, -5.0)

    # Limit masses to be below MTOV
    pcs, ms, rs, lambdas = utils.limit_by_MTOV(pcs, ms, rs, lambdas)

    # Resample on a uniform mass grid, since the truncation may leave duplicate points
    mass_grid = jnp.linspace(jnp.min(ms), jnp.max(ms), ndat)
    rs = jnp.interp(mass_grid, ms, rs)
    lambdas = jnp.interp(mass_grid, ms, lambdas)
    pcs = jnp.interp(mass_grid, ms, pcs)

    return jnp.log(pcs), mass_grid, rs, lambdas
def construct_family_nonGR(eos: tuple, ndat: Int = 50, min_nsat: Float = 2) -> tuple[
    Float[Array, "ndat"],
    Float[Array, "ndat"],
    Float[Array, "ndat"],
    Float[Array, "ndat"],
]:
    """
    Solve the post-TOV equations and generate the M, R and Lambda curves.

    Args:
        eos (tuple): Tuple of 11 entries: the EOS data (ns, ps, hs, es, dloge_dlogps)
            followed by the post-TOV parameters (alpha, beta, gamma, lambda_BL, lambda_DY,
            lambda_HB), which are passed on to ``ptov.tov_solver``.
        ndat (int, optional): Number of datapoints used when constructing the central
            pressure grid. Defaults to 50.
        min_nsat (int, optional): Starting density for central pressure in numbers of nsat
            (assumed to be 0.16 fm^-3). Defaults to 2.

    Returns:
        tuple[Float[Array, "ndat"], Float[Array, "ndat"], Float[Array, "ndat"], Float[Array, "ndat"]]:
        log(pcs), masses in solar masses, radii in km, and dimensionless tidal
        deformabilities.
    """
    (
        ns,
        ps,
        hs,
        es,
        dloge_dlogps,
        alpha,
        beta,
        gamma,
        lambda_BL,
        lambda_DY,
        lambda_HB,
    ) = eos
    eos_dict = {
        "p": ps,
        "h": hs,
        "e": es,
        "dloge_dlogp": dloge_dlogps,
        "alpha": alpha,
        "beta": beta,
        "gamma": gamma,
        "lambda_BL": lambda_BL,
        "lambda_DY": lambda_DY,
        "lambda_HB": lambda_HB,
    }

    # Lowest central pressure: the pressure at min_nsat * 0.16 fm^-3
    pc_min = utils.interp_in_logspace(
        min_nsat * 0.16 * utils.fm_inv3_to_geometric, ns, ps
    )

    # Highest central pressure: the pressure at the first non-causal point
    # (index -1, i.e. the last point, if the EOS is causal everywhere)
    cs2 = ps / es / dloge_dlogps
    pc_max = eos_dict["p"][locate_lowest_non_causal_point(cs2)]

    pcs = jnp.logspace(jnp.log10(pc_min), jnp.log10(pc_max), num=ndat)

    # TODO: compare timing against a jnp.vectorize-based implementation
    ms, rs, ks = jax.vmap(lambda pc: ptov.tov_solver(eos_dict, pc))(pcs)

    compactness = ms / rs
    ms = ms / utils.solar_mass_in_meter  # to solar masses
    rs = rs / 1e3  # to km

    # Lambda = 2/3 * k * C^-5 (ks presumably being the Love number k2 -- confirm in ptov)
    lambdas = 2.0 / 3.0 * ks * jnp.power(compactness, -5.0)

    # TODO: perhaps put a boolean here to flag whether or not to do this, or do we always want to do this?
    # Limit masses to be below MTOV
    pcs, ms, rs, lambdas = utils.limit_by_MTOV(pcs, ms, rs, lambdas)

    # Resample on a uniform mass grid, since the truncation may leave duplicate points
    mass_grid = jnp.linspace(jnp.min(ms), jnp.max(ms), ndat)
    rs = jnp.interp(mass_grid, ms, rs)
    lambdas = jnp.interp(mass_grid, ms, lambdas)
    pcs = jnp.interp(mass_grid, ms, pcs)

    return jnp.log(pcs), mass_grid, rs, lambdas