diff --git a/lace/emulator/gp_emulator.py b/lace/emulator/gp_emulator.py index 851ea15c..0603af6b 100644 --- a/lace/emulator/gp_emulator.py +++ b/lace/emulator/gp_emulator.py @@ -689,6 +689,9 @@ def emulate_p1d_Mpc(self, model, k_Mpc, return_covar=False, z=None): except: length = 1 + if k_Mpc.ndim == 1: + k_Mpc = np.repeat(k_Mpc[None, :], length, axis=0) + # only implemented for length == 1 if self.hull: if length == 1: @@ -707,7 +710,13 @@ def emulate_p1d_Mpc(self, model, k_Mpc, return_covar=False, z=None): kind="cubic", fill_value="extrapolate", ) - p1d = interpolator(k_Mpc) + p1d = np.zeros_like(k_Mpc) + for ii in range(k_Mpc.shape[0]): + p1d[ii] = interpolator(k_Mpc[ii]) + + if length == 1: + p1d = p1d[0] + if not return_covar: return p1d else: @@ -718,29 +727,40 @@ def emulate_p1d_Mpc(self, model, k_Mpc, return_covar=False, z=None): kind="cubic", fill_value="extrapolate", ) - p1d_err = err_interp(k_Mpc) + p1d_err = np.zeros_like(k_Mpc) + for ii in range(k_Mpc.shape[0]): + p1d_err[ii] = err_interp(k_Mpc[ii]) + if self.emu_per_k: covar = np.diag(p1d_err**2) else: # assume fully correlated errors when using same hyperparams - covar = np.outer(p1d_err, p1d_err) + covar = np.zeros( + (k_Mpc.shape[0], k_Mpc.shape[1], k_Mpc.shape[1]) + ) + for ii in range(k_Mpc.shape[0]): + covar[ii] = np.outer(p1d_err[ii], p1d_err[ii]) + + if length == 1: + covar = covar[0] + return p1d, covar elif self.emu_type == "polyfit": # gp_pred here are just the coefficients of the polynomial - p1d = np.zeros((gp_pred.shape[0], k_Mpc.shape[0])) + p1d = np.zeros((gp_pred.shape[0], k_Mpc.shape[1])) for ii in range(gp_pred.shape[0]): poly = np.poly1d(gp_pred[ii]) - p1d[ii] = np.exp(poly(np.log(k_Mpc))) + p1d[ii] = np.exp(poly(np.log(k_Mpc[ii]))) if not return_covar: if length == 1: p1d = p1d[0] return p1d - lk = np.log(k_Mpc) - covar = np.zeros((gp_pred.shape[0], k_Mpc.shape[0], k_Mpc.shape[0])) + covar = np.zeros((gp_pred.shape[0], k_Mpc.shape[1], k_Mpc.shape[1])) for ii in range(gp_pred.shape[0]): + lk = np.log(k_Mpc[ii]) erry2 = ( (gp_err[ii, 0] * lk**4) ** 2 + (gp_err[ii, 1] * lk**3) ** 2