Skip to content

Commit

Permalink
diff k per call for gp emu
Browse files Browse the repository at this point in the history
  • Loading branch information
jchavesmontero committed Aug 28, 2024
1 parent 5ad11a3 commit e213a3c
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions lace/emulator/gp_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,9 @@ def emulate_p1d_Mpc(self, model, k_Mpc, return_covar=False, z=None):
except:
length = 1

if k_Mpc.ndim == 1:
k_Mpc = np.repeat(k_Mpc[None, :], length, axis=0)

# only implemented for length == 1
if self.hull:
if length == 1:
Expand All @@ -707,7 +710,13 @@ def emulate_p1d_Mpc(self, model, k_Mpc, return_covar=False, z=None):
kind="cubic",
fill_value="extrapolate",
)
p1d = interpolator(k_Mpc)
p1d = np.zeros_like(k_Mpc)
for ii in range(k_Mpc.shape[0]):
p1d[ii] = interpolator(k_Mpc[ii])

if length == 1:
p1d = p1d[0]

if not return_covar:
return p1d
else:
Expand All @@ -718,29 +727,40 @@ def emulate_p1d_Mpc(self, model, k_Mpc, return_covar=False, z=None):
kind="cubic",
fill_value="extrapolate",
)
p1d_err = err_interp(k_Mpc)
p1d_err = np.zeros_like(k_Mpc)
for ii in range(k_Mpc.shape[0]):
p1d_err[ii] = err_interp(k_Mpc[ii])

if self.emu_per_k:
covar = np.diag(p1d_err**2)
else:
# assume fully correlated errors when using same hyperparams
covar = np.outer(p1d_err, p1d_err)
covar = np.zeros(
(k_Mpc.shape[0], k_Mpc.shape[1], k_Mpc.shape[1])
)
for ii in range(k_Mpc.shape[0]):
covar[ii] = np.outer(p1d_err[ii], p1d_err[ii])

if length == 1:
covar = covar[0]

return p1d, covar

elif self.emu_type == "polyfit":
# gp_pred here are just the coefficients of the polynomial
p1d = np.zeros((gp_pred.shape[0], k_Mpc.shape[0]))
p1d = np.zeros((gp_pred.shape[0], k_Mpc.shape[1]))
for ii in range(gp_pred.shape[0]):
poly = np.poly1d(gp_pred[ii])
p1d[ii] = np.exp(poly(np.log(k_Mpc)))
p1d[ii] = np.exp(poly(np.log(k_Mpc[ii])))

if not return_covar:
if length == 1:
p1d = p1d[0]
return p1d

lk = np.log(k_Mpc)
covar = np.zeros((gp_pred.shape[0], k_Mpc.shape[0], k_Mpc.shape[0]))
covar = np.zeros((gp_pred.shape[0], k_Mpc.shape[1], k_Mpc.shape[1]))
for ii in range(gp_pred.shape[0]):
lk = np.log(k_Mpc[ii])
erry2 = (
(gp_err[ii, 0] * lk**4) ** 2
+ (gp_err[ii, 1] * lk**3) ** 2
Expand Down

0 comments on commit e213a3c

Please sign in to comment.