Skip to content

Commit

Permalink
Add updated transform and inverse_transform
Browse files Browse the repository at this point in the history
  • Loading branch information
sdahdah committed Aug 15, 2023
1 parent 9dd2239 commit eca7d6d
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 223 deletions.
152 changes: 87 additions & 65 deletions pykoop/koopman_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,7 +1741,8 @@ def fit(self,
for _, lf in self.lifting_functions_input_:
X_out_input = lf.fit_transform(
X_out_input,
n_inputs=X_out_input.shape[1],
n_inputs=X_out_input.shape[1] -
(1 if self.episode_feature_ else 0),
episode_feature=self.episode_feature_,
)
# Compute output dimensions for states
Expand Down Expand Up @@ -1787,43 +1788,53 @@ def transform(self, X: np.ndarray) -> np.ndarray:
f'with {self.n_features_in_} features, but '
f'`transform()` called with {X.shape[1]} '
'features.')
# Separate episode feature
if self.episode_feature_:
X_ep = X[:, [0]]
X = X[:, 1:]
# Split state and input
X_state = X[:, :self.n_states_in_]
X_input = X[:, self.n_states_in_:]
# Put back episode feature if needed
if self.episode_feature_:
X_state = np.hstack((
X_ep,
X_state,
))
X_input = np.hstack((
X_ep,
X_input,
))
# Split episodes
episodes = split_episodes(X, episode_feature=self.episode_feature_)
episodes_state = []
episodes_input = []
for (i, X_i) in episodes:
# Split state and input
episodes_state.append((i, X_i[:, :self.n_states_in_]))
episodes_input.append((i, X_i[:, self.n_states_in_:]))
X_state = combine_episodes(
episodes_state,
episode_feature=self.episode_feature_,
)
X_input = combine_episodes(
episodes_input,
episode_feature=self.episode_feature_,
)
# Fit and transform states
X_out_state = X_state
Xt_state = X_state
for _, lf in self.lifting_functions_state_:
X_out_state = lf.transform(X_out_state)
Xt_state = lf.transform(Xt_state)
# Fit and transform inputs
X_out_input = X_input
Xt_input = X_input
for _, lf in self.lifting_functions_input_:
X_out_input = lf.transform(X_out_input)
# Truncate to same shape
n_samples = min(X_out_state.shape[0], X_out_input.shape[0])
if self.episode_feature_:
Xt = np.hstack((
X_out_state[-n_samples:, :],
X_out_input[-n_samples:, 1:],
))
else:
Xt = np.hstack((
X_out_state[-n_samples:, :],
X_out_input[-n_samples:, :],
Xt_input = lf.transform(Xt_input)
# Split up transformed episodes
episodes_t_state = split_episodes(
Xt_state,
episode_feature=self.episode_feature_,
)
episodes_t_input = split_episodes(
Xt_input,
episode_feature=self.episode_feature_,
)
episodes_t_zipped = zip(episodes_t_state, episodes_t_input)
episodes_t = []
for ((i, Xt_state_i), (_, Xt_input_i)) in episodes_t_zipped:
# Truncate to same shape
n_samples = min(Xt_state_i.shape[0], Xt_input_i.shape[0])
Xt_i = np.hstack((
Xt_state_i[-n_samples:, :],
Xt_input_i[-n_samples:, :],
))
episodes_t.append((i, Xt_i))
Xt = combine_episodes(
episodes_t,
episode_feature=self.episode_feature_,
)
return Xt

def inverse_transform(self, X: np.ndarray) -> np.ndarray:
Expand All @@ -1836,42 +1847,53 @@ def inverse_transform(self, X: np.ndarray) -> np.ndarray:
f'{self.n_features_out_} features, but '
'`inverse_transform()` called with '
f'{X.shape[1]} features.')
if self.episode_feature_:
X_ep = X[:, [0]]
X = X[:, 1:]
# Split state and input
X_state = X[:, :self.n_states_out_]
X_input = X[:, self.n_states_out_:]
# Put back episode feature if needed
if self.episode_feature_:
X_state = np.hstack((
X_ep,
X_state,
))
X_input = np.hstack((
X_ep,
X_input,
))
# Fit and inverse transform states
X_out_state = X_state
# Split episodes
episodes = split_episodes(X, episode_feature=self.episode_feature_)
episodes_state = []
episodes_input = []
for (i, X_i) in episodes:
# Split state and input
episodes_state.append((i, X_i[:, :self.n_states_out_]))
episodes_input.append((i, X_i[:, self.n_states_out_:]))
X_state = combine_episodes(
episodes_state,
episode_feature=self.episode_feature_,
)
X_input = combine_episodes(
episodes_input,
episode_feature=self.episode_feature_,
)
# Fit and transform states
Xt_state = X_state
for _, lf in self.lifting_functions_state_[::-1]:
X_out_state = lf.inverse_transform(X_out_state)
Xt_state = lf.inverse_transform(Xt_state)
# Fit and transform inputs
X_out_input = X_input
Xt_input = X_input
for _, lf in self.lifting_functions_input_[::-1]:
X_out_input = lf.inverse_transform(X_out_input)
# Truncate to same shape
n_samples = min(X_out_state.shape[0], X_out_input.shape[0])
if self.episode_feature_:
Xt = np.hstack((
X_out_state[-n_samples:, :],
X_out_input[-n_samples:, 1:],
))
else:
Xt = np.hstack((
X_out_state[-n_samples:, :],
X_out_input[-n_samples:, :],
Xt_input = lf.inverse_transform(Xt_input)
# Split up transformed episodes
episodes_t_state = split_episodes(
Xt_state,
episode_feature=self.episode_feature_,
)
episodes_t_input = split_episodes(
Xt_input,
episode_feature=self.episode_feature_,
)
episodes_t_zipped = zip(episodes_t_state, episodes_t_input)
episodes_t = []
for ((i, Xt_state_i), (_, Xt_input_i)) in episodes_t_zipped:
# Truncate to same shape
n_samples = min(Xt_state_i.shape[0], Xt_input_i.shape[0])
Xt_i = np.hstack((
Xt_state_i[-n_samples:, :],
Xt_input_i[-n_samples:, :],
))
episodes_t.append((i, Xt_i))
Xt = combine_episodes(
episodes_t,
episode_feature=self.episode_feature_,
)
return Xt

def n_samples_in(self, n_samples_out: int = 1) -> int:
Expand Down
Loading

0 comments on commit eca7d6d

Please sign in to comment.