-
Notifications
You must be signed in to change notification settings - Fork 0
/
nest_info.py
108 lines (89 loc) · 3.7 KB
/
nest_info.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import typing
from collections import namedtuple
from dataclasses import dataclass
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
# Immutable (e, n) pair. The plotting code below treats these as the east and
# north axes of the map frame.
Coordinate = namedtuple("Coordinate", "e n")
@dataclass
class NestInfo:
"""
This is a class for interacting with our "Nest".
It will store things like the name of the nest, the coordinate we launch and recover from,
as well as information specific to this nest - such as places we deliver, called delivery_sites.
"""
LOW_RISK_VALUE = 0
HIGH_RISK_VALUE = 1
KEEP_OUT_VALUE = 2
def __init__(self, config_dict: dict):
self.name = config_dict["nest_name"]
self.nest_coord = Coordinate(
e=config_dict["nest_coord"]["e_coord"],
n=config_dict["nest_coord"]["n_coord"],
)
self.risk_zones = np.load(config_dict["risk_zones_path"])
self.maximum_range = config_dict["maximum_range"]
def display(self, ax: plt.Axes):
# Show the risk map, making sure that the map is in E-N coordinate frame
im = plt.imshow(self.risk_zones.T, cmap="Blues", origin="lower")
# Set legend entries to match the map color scheme
risk_values = [
NestInfo.LOW_RISK_VALUE,
NestInfo.HIGH_RISK_VALUE,
NestInfo.KEEP_OUT_VALUE,
]
risk_colors = [im.cmap(im.norm(value)) for value in risk_values]
handles = []
handles.append(mpatches.Patch(color=risk_colors[0], label="Low Risk"))
handles.append(mpatches.Patch(color=risk_colors[1], label="High Risk"))
handles.append(mpatches.Patch(color=risk_colors[2], label="Keep Out"))
# Plot the nest location as a red triangle
nest_handle = plt.plot(*self.nest_coord, "r^", markersize=10, label="Nest Location")
handles.extend(nest_handle)
return handles
class DeliverySite:
def __init__(self, e: int, n: int, site_id: int, name: str):
self.coord = Coordinate(e=e, n=n)
self.site_id = site_id
self.name = name
self.path = []
color_list = plt.rcParams["axes.prop_cycle"].by_key()["color"]
self.color = color_list[self.site_id % len(color_list)]
def set_path(self, path: typing.List["Coordinate"]) -> None:
for coord in path:
if not isinstance(coord, Coordinate):
raise ValueError(
f"""
The path you are trying to set for site {self.name} is of an invalid type.
You need to provide a list of type Coordinate.
The first invalid type in the path was: {coord}
Full path: {path}
"""
)
# if execution makes it here, we have a valid path type.
self.path = path
def display(self, ax: plt.Axes):
if len(self.path) > 0:
e_coords = [coord.e for coord in self.path]
n_coords = [coord.n for coord in self.path]
plt.plot(e_coords, n_coords, ".-", color=self.color)
handle = ax.plot(*self.coord, "o", color=self.color, label=f"{self.site_id}: {self.name}")
else:
handle = ax.plot(
*self.coord,
"rx",
color=self.color,
label=f"{self.site_id}: {self.name}",
)
return handle
def load_delivery_sites(config_dict: dict) -> typing.List["DeliverySite"]:
    """
    Build one ``DeliverySite`` per entry in ``config_dict["delivery_sites"]``.

    Each entry must provide ``"e_coord"``, ``"n_coord"`` and ``"site_id"``
    (a missing key raises ``KeyError``). ``"name"`` is optional and is passed
    through as ``None`` when absent. A config with no ``"delivery_sites"``
    key yields an empty list.
    """
    return [
        DeliverySite(
            e=entry["e_coord"],
            n=entry["n_coord"],
            site_id=entry["site_id"],
            name=entry.get("name"),
        )
        for entry in config_dict.get("delivery_sites", [])
    ]