Skip to content

Commit

Permalink
Merge pull request #16 from AllenNeuralDynamics/licks
Browse files Browse the repository at this point in the history
Annotating licks into bouts
  • Loading branch information
alexpiet authored Sep 6, 2024
2 parents 9bc62f2 + 905acc9 commit 9d36f96
Show file tree
Hide file tree
Showing 7 changed files with 504 additions and 220 deletions.
159 changes: 159 additions & 0 deletions src/aind_dynamic_foraging_basic_analysis/licks/annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""
Tools for annotation of lick bouts
df_licks = annotate_lick_bouts(nwb)
df_licks = annotate_rewards(nwb)
df_licks = annotate_cue_response(nwb)
"""

import numpy as np


def annotate_lick_bouts(nwb, bout_threshold=0.7):
    """
    Returns a dataframe of lick times with annotations:
        pre_ili, the elapsed time since the last lick (on either side)
        post_ili, the time until the next lick (on either side)
        bout_start (bool), whether this was the start of a lick bout
        bout_end (bool), whether this was the end of a lick bout
        bout_number (int), what lick bout this was a part of

    nwb, an nwb-like object with attributes: df_events
    bout_threshold, the inter-lick interval (s) that determines bout segmentation
    """

    if not hasattr(nwb, "df_events"):
        print("You need to compute df_events: nwb_utils.create_events_df(nwb)")
        return
    df_licks = nwb.df_events.query('event in ["right_lick_time","left_lick_time"]').copy()
    df_licks.reset_index(drop=True, inplace=True)

    # Handle a session with no licks explicitly: the np.concatenate calls
    # below produce length-1 arrays that cannot be assigned to a 0-row frame,
    # and bout_number.max() would be NaN. Return the empty frame with all
    # annotation columns present.
    if len(df_licks) == 0:
        df_licks["pre_ili"] = np.array([], dtype=float)
        df_licks["post_ili"] = np.array([], dtype=float)
        df_licks["bout_start"] = np.array([], dtype=bool)
        df_licks["bout_end"] = np.array([], dtype=bool)
        df_licks["bout_number"] = np.array([], dtype=int)
        return df_licks

    # Computing ILI for each lick; the first lick has no pre_ili and the
    # last lick has no post_ili, so those are NaN
    df_licks["pre_ili"] = np.concatenate([[np.nan], np.diff(df_licks.timestamps.values)])
    df_licks["post_ili"] = np.concatenate([np.diff(df_licks.timestamps.values), [np.nan]])

    # Assign licks into bouts: a gap longer than bout_threshold ends one
    # bout and starts the next. The session's first/last licks (NaN ILIs)
    # always start/end a bout.
    df_licks["bout_start"] = df_licks["pre_ili"] > bout_threshold
    df_licks["bout_end"] = df_licks["post_ili"] > bout_threshold
    df_licks.loc[df_licks["pre_ili"].isnull(), "bout_start"] = True
    df_licks.loc[df_licks["post_ili"].isnull(), "bout_end"] = True
    df_licks["bout_number"] = np.cumsum(df_licks["bout_start"])

    # Internal consistency checks: bouts must start and stop in pairs
    num_bout_start = df_licks["bout_start"].sum()
    num_bout_end = df_licks["bout_end"].sum()
    num_bouts = df_licks["bout_number"].max()
    assert num_bout_start == num_bout_end, "Bout Starts and Bout Ends don't align"
    assert num_bout_start == num_bouts, "Number of bouts is incorrect"

    return df_licks


def annotate_rewards(nwb):
    """
    Annotates df_licks with which lick triggered each reward

    Adds columns:
        rewarded (bool), whether this lick triggered a reward delivery
        bout_rewarded (bool), whether any lick in this lick's bout was rewarded

    nwb, an nwb-like object with attributes: df_licks, df_events
    """

    # A lick must fall within this window (s) before reward delivery
    # to be credited with triggering it
    LICK_TO_REWARD_TOLERANCE = 0.25

    if not hasattr(nwb, "df_events"):
        print("You need to compute df_events: nwb_utils.create_events_df(nwb)")
        return

    # ensure we have df_licks
    if not hasattr(nwb, "df_licks"):
        print("annotating lick bouts")
        nwb.df_licks = annotate_lick_bouts(nwb)

    # make a copy of df licks
    df_licks = nwb.df_licks.copy()

    # set default to false
    df_licks["rewarded"] = False

    # For each reward on each side, find the most recent same-side lick
    # within tolerance and mark it as the triggering lick
    for side in ["right", "left"]:
        rewards = nwb.df_events.query('event == "{}_reward_delivery_time"'.format(side))
        for index, row in rewards.iterrows():
            this_reward_lick_times = np.where(
                (df_licks.timestamps <= row.timestamps)
                & (df_licks.timestamps > (row.timestamps - LICK_TO_REWARD_TOLERANCE))
                & (df_licks.event == "{}_lick_time".format(side))
            )[0]
            if len(this_reward_lick_times) > 0:
                df_licks.at[this_reward_lick_times[-1], "rewarded"] = True
        # TODO, should check for licks that happened before the last go cue
        # TODO, if we can't find a matching lick, should ensure this is manual or auto water

    # Annotate lick bouts as rewarded or unrewarded: aggregate the single
    # "rewarded" column per bout, then broadcast back to each lick.
    # (Previously `.any("rewarded")` passed the string into GroupBy.any's
    # skipna argument and aggregated every column.)
    bout_rewarded = df_licks.groupby("bout_number")["rewarded"].any()
    df_licks["bout_rewarded"] = df_licks["bout_number"].map(bout_rewarded)

    return df_licks


def annotate_cue_response(nwb):
    """
    Annotates df_licks with which lick was immediately after each go cue

    Adds columns:
        cue_response (bool), whether this was the first lick after a go cue
        bout_cue_response (bool), whether any lick in this lick's bout was a cue response

    nwb, an nwb-like object with attributes: df_licks, df_events
    """

    # A lick must fall within this window (s) after a go cue
    # to count as a response to it
    CUE_TO_LICK_TOLERANCE = 1

    if not hasattr(nwb, "df_events"):
        print("You need to compute df_events: nwb_utils.create_events_df(nwb)")
        return

    # ensure we have df_licks
    if not hasattr(nwb, "df_licks"):
        print("annotating lick bouts")
        nwb.df_licks = annotate_lick_bouts(nwb)

    # make a copy of df licks
    df_licks = nwb.df_licks.copy()

    # set default to false
    df_licks["cue_response"] = False

    # For each go cue, mark the first lick (either side) within tolerance
    # as the cue response
    cues = nwb.df_events.query('event == "goCue_start_time"').copy()
    for index, row in cues.iterrows():
        this_lick_times = np.where(
            (df_licks.timestamps > row.timestamps)
            & (df_licks.timestamps <= (row.timestamps + CUE_TO_LICK_TOLERANCE))
            & (df_licks.event.isin(["right_lick_time", "left_lick_time"]))
        )[0]
        if len(this_lick_times) > 0:
            df_licks.at[this_lick_times[0], "cue_response"] = True

    # Annotate lick bouts as cue_responsive or unresponsive: aggregate the
    # single "cue_response" column per bout, then broadcast back per lick.
    # (Previously `.any("cue_response")` passed the string into GroupBy.any's
    # skipna argument and aggregated every column.)
    bout_cue_response = df_licks.groupby("bout_number")["cue_response"].any()
    df_licks["bout_cue_response"] = df_licks["bout_number"].map(bout_cue_response)

    return df_licks
39 changes: 27 additions & 12 deletions src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aind_dynamic_foraging_basic_analysis.plot.style import STYLE, FIP_COLORS


def plot_fip_psth_compare_alignments(nwb, alignments, channel, tw=[-4, 4]):
def plot_fip_psth_compare_alignments(nwb, alignments, channel, tw=[-4, 4], censor=True):
"""
Compare the same FIP channel aligned to multiple event types
nwb, nwb object for the session
Expand Down Expand Up @@ -54,12 +54,20 @@ def plot_fip_psth_compare_alignments(nwb, alignments, channel, tw=[-4, 4]):
)
return

censor_times = []
for key in align_dict:
censor_times.append(align_dict[key])
censor_times = np.sort(np.concatenate(censor_times))

align_label = "Time (s)"

fig, ax = plt.subplots()

for alignment in align_dict:
etr = fip_psth_inner_compute(nwb, align_dict[alignment], channel, True, tw)
fip_psth_inner_plot(ax, etr, FIP_COLORS.get(alignment, "k"), alignment)
etr = fip_psth_inner_compute(
nwb, align_dict[alignment], channel, True, tw, censor, censor_times
)
fip_psth_inner_plot(ax, etr, FIP_COLORS.get(alignment, ""), alignment)

plt.legend()
ax.set_xlabel(align_label, fontsize=STYLE["axis_fontsize"])
Expand All @@ -69,6 +77,7 @@ def plot_fip_psth_compare_alignments(nwb, alignments, channel, tw=[-4, 4]):
ax.set_xlim(tw)
ax.axvline(0, color="k", alpha=0.2)
ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"])
ax.set_title(nwb.session_id, fontsize=STYLE["axis_fontsize"])
plt.tight_layout()
return fig, ax

Expand All @@ -85,15 +94,13 @@ def plot_fip_psth_compare_channels(
"Iso_1_preprocessed",
"Iso_2_preprocessed",
],
censor=True,
):
"""
TODO, need to censor by next event
todo, clean up plots
todo, figure out a modular system for comparing alignments, and channels
todo, annotate licks into bouts, start of bout, etc
nwb, the nwb object for the session of interest
align should either be a string of the name of an event type in nwb.df_events,
or a list of timepoints
channels should be a list of channel names (strings)
EXAMPLE
********************
plot_fip_psth(nwb, 'goCue_start_time')
Expand All @@ -119,10 +126,10 @@ def plot_fip_psth_compare_channels(

fig, ax = plt.subplots()

colors = [FIP_COLORS.get(c, "k") for c in channels]
colors = [FIP_COLORS.get(c, "") for c in channels]
for dex, c in enumerate(channels):
if c in nwb.fip_df["event"].values:
etr = fip_psth_inner_compute(nwb, align_timepoints, c, True, tw)
etr = fip_psth_inner_compute(nwb, align_timepoints, c, True, tw, censor)
fip_psth_inner_plot(ax, etr, colors[dex], c)
else:
print("No data for channel: {}".format(c))
Expand All @@ -135,6 +142,7 @@ def plot_fip_psth_compare_channels(
ax.set_xlim(tw)
ax.axvline(0, color="k", alpha=0.2)
ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"])
ax.set_title(nwb.session_id)
plt.tight_layout()
return fig, ax

Expand All @@ -147,11 +155,16 @@ def fip_psth_inner_plot(ax, etr, color, label):
color, the line color to plot
label, the label for the etr
"""
if color == "":
cmap = plt.get_cmap("tab20")
color = cmap(np.random.randint(20))
ax.fill_between(etr.index, etr.data - etr["sem"], etr.data + etr["sem"], color=color, alpha=0.2)
ax.plot(etr.index, etr.data, color, label=label)
ax.plot(etr.index, etr.data, color=color, label=label)


def fip_psth_inner_compute(nwb, align_timepoints, channel, average, tw=[-1, 1]):
def fip_psth_inner_compute(
nwb, align_timepoints, channel, average, tw=[-1, 1], censor=True, censor_times=None
):
"""
helper function that computes the event triggered response
nwb, nwb object for the session of interest, should have fip_df attribute
Expand All @@ -170,6 +183,8 @@ def fip_psth_inner_compute(nwb, align_timepoints, channel, average, tw=[-1, 1]):
t_start=tw[0],
t_end=tw[1],
output_sampling_rate=40,
censor=censor,
censor_times=censor_times,
)

if average:
Expand Down
Loading

0 comments on commit 9d36f96

Please sign in to comment.