Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved CompositeState and friends to dedicated file. Ran black. #105

Merged
merged 8 commits into from
Jan 20, 2024
58 changes: 35 additions & 23 deletions docs/source/fun_figs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

sns.set_style("whitegrid")


# %% Banana distribution plot
def banana_plot(ax = None):
def banana_plot(ax=None):
N = 500
x0 = nav.lib.SE2State([0.3, 3, 4], direction="right")
covariance = np.diag([0.2**2, 0.05**2, 0.05**2])
Expand Down Expand Up @@ -38,15 +39,22 @@ def banana_plot(ax = None):

# random greyscale color
color = np.random.uniform(0.3, 0.9)
ax.plot(traj_pos[:, 0], traj_pos[:, 1], color=(color, color, color),zorder=1)
ax.plot(
traj_pos[:, 0],
traj_pos[:, 1],
color=(color, color, color),
zorder=1,
)

# save the final state
final_states.append(x_traj[-1])

final_positions = np.array([x.position for x in final_states])
ax.scatter(final_positions[:, 0], final_positions[:, 1], color="C0", zorder=2)
ax.scatter(
final_positions[:, 0], final_positions[:, 1], color="C0", zorder=2
)

# Propagate the mean with EKF
# Propagate the mean with EKF
kf = nav.ExtendedKalmanFilter(process_model)
x0_hat = nav.StateWithCovariance(x0, covariance)

Expand All @@ -59,10 +67,12 @@ def banana_plot(ax = None):
ax.plot(mean_traj[:, 0], mean_traj[:, 1], color="r", zorder=3, linewidth=3)
ax.set_aspect("equal")


# banana_plot()


# %%
def pose3d_plot(ax = None):
def pose3d_plot(ax=None):
N = 500
x0 = nav.lib.SE3State([0.3, 3, 4, 0, 0, 0], direction="right")
process_model = nav.lib.BodyFrameVelocity(np.zeros(6))
Expand All @@ -78,24 +88,27 @@ def pose3d_plot(ax = None):
x = process_model.evaluate(x, u, dt)
x_traj.append(x.copy())

fig, ax = nav.plot_poses(x_traj, ax = ax)
fig, ax = nav.plot_poses(x_traj, ax=ax)


# pose3d_plot()


# %%
def three_sigma_plot(axs = None):
def three_sigma_plot(axs=None):
dataset = nav.lib.datasets.SimulatedPoseRangingDataset()

estimates = nav.run_filter(
nav.ExtendedKalmanFilter(dataset.process_model),
dataset.get_ground_truth()[0],
np.diag([0.1**2, 0.1**2, 0.1**2, 0.1**2, 0.1**2, 0.1**2]),
dataset.get_input_data(),
dataset.get_measurement_data()
)

results = nav.GaussianResultList.from_estimates(estimates, dataset.get_ground_truth())
dataset.get_measurement_data(),
)

results = nav.GaussianResultList.from_estimates(
estimates, dataset.get_ground_truth()
)

fig, axs = nav.plot_error(results[:, :3], axs=axs)
axs[2].set_xlabel("Time (s)")
Expand All @@ -105,7 +118,6 @@ def three_sigma_plot(axs = None):


if __name__ == "__main__":

# Make one large figure which has all the plots. This will be a 1x3 grid, with the
# last plot itself being a three vertically stacked plots.

Expand All @@ -119,13 +131,11 @@ def three_sigma_plot(axs = None):

# which will be used here:



fig = plt.figure(figsize=(20, 6))
gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 1])
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1], projection='3d')
ax2 = fig.add_subplot(gs[1], projection="3d")

# The last plot is a 3x1 grid
gs2 = gs[2].subgridspec(3, 1, hspace=0.1)
ax3 = fig.add_subplot(gs2[0])
Expand All @@ -141,25 +151,27 @@ def three_sigma_plot(axs = None):
ax2.set_yticklabels([])
ax2.set_zticklabels([])


banana_plot(ax1)
pose3d_plot(ax2)
three_sigma_plot(np.array([ax3, ax4, ax5]))

# Set spacing to the above values
# Set spacing to the above values
fig.subplots_adjust(
top=0.975,
bottom=0.097,
left=0.025,
right=0.992,
hspace=0.2,
wspace=0.117
wspace=0.117,
)


# Save the figure with transparent background, next to this file
import os
fig.savefig(os.path.join(os.path.dirname(__file__), "fun_figs.png"), transparent=True)
import os

fig.savefig(
os.path.join(os.path.dirname(__file__), "fun_figs.png"),
transparent=True,
)

plt.show()
# %%
8 changes: 7 additions & 1 deletion navlie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,11 @@
jacobian,
)

from .composite import (
CompositeState,
CompositeProcessModel,
CompositeMeasurementModel,
CompositeInput,
)

from .lib.states import StampedValue # for backwards compatibility
from .lib.states import StampedValue # for backwards compatibility
Loading
Loading