Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add plot_foraging_session #5

Merged
merged 21 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ readme = "README.md"
dynamic = ["version"]

dependencies = [
'numpy'
'numpy',
'matplotlib',
'pydantic'
]

[project.optional-dependencies]
Expand Down
1 change: 1 addition & 0 deletions src/aind_dynamic_foraging_basic_analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
__version__ = "0.0.0"

from .foraging_efficiency import compute_foraging_efficiency # noqa: F401
from .plot.plot_foraging_session import plot_foraging_session # noqa: F401
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Pydantic data model of foraging session data for shared validation.

Maybe this is an overkill...
"""

from typing import List, Optional

import numpy as np
from pydantic import BaseModel, field_validator, model_validator


class PhotostimData(BaseModel):
"""Photostimulation data"""

trial: List[int]
power: List[float]
stim_epoch: Optional[List[str]] = None

class Config:
"""Allow np.ndarray as input"""

arbitrary_types_allowed = True


class ForagingSessionData(BaseModel):
"""Shared validation for foraging session data"""

choice_history: np.ndarray
reward_history: np.ndarray
p_reward: Optional[np.ndarray] = None
random_number: Optional[np.ndarray] = None
autowater_offered: Optional[np.ndarray] = None
fitted_data: Optional[np.ndarray] = None
photostim: Optional[PhotostimData] = None

class Config:
"""Allow np.ndarray as input"""

arbitrary_types_allowed = True

@field_validator(
"choice_history",
"reward_history",
"p_reward",
"random_number",
"autowater_offered",
"fitted_data",
mode="before",
)
@classmethod
def convert_to_ndarray(cls, v, info):
"""Always convert to numpy array"""
return (
np.array(
v,
dtype=(
"bool"
if info.field_name in ["reward_history", "autowater_offered"] # Turn to bool
else None
),
)
if v is not None
else None
)

@model_validator(mode="after")
def check_all_fields(cls, values): # noqa: C901
"""Check consistency of all fields"""

choice_history = values.choice_history
reward_history = values.reward_history
p_reward = values.p_reward
random_number = values.random_number
autowater_offered = values.autowater_offered
fitted_data = values.fitted_data
photostim = values.photostim

if not np.all(np.isin(choice_history, [0.0, 1.0]) | np.isnan(choice_history)):
raise ValueError("choice_history must contain only 0, 1, or np.nan.")

if choice_history.shape != reward_history.shape:
raise ValueError("choice_history and reward_history must have the same shape.")

if p_reward.shape != (2, len(choice_history)):
raise ValueError("reward_probability must have the shape (2, n_trials)")

if random_number is not None and random_number.shape != p_reward.shape:
raise ValueError("random_number must have the same shape as reward_probability.")

if autowater_offered is not None and autowater_offered.shape != choice_history.shape:
raise ValueError("autowater_offered must have the same shape as choice_history.")

if fitted_data is not None and fitted_data.shape[0] != choice_history.shape[0]:
raise ValueError("fitted_data must have the same length as choice_history.")

if photostim is not None:
if len(photostim.trial) != len(photostim.power):
raise ValueError("photostim.trial must have the same length as photostim.power.")
if photostim.stim_epoch is not None and len(photostim.stim_epoch) != len(
photostim.power
):
raise ValueError(
"photostim.stim_epoch must have the same length as photostim.power."
)

return values
42 changes: 14 additions & 28 deletions src/aind_dynamic_foraging_basic_analysis/foraging_efficiency.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import numpy as np

from aind_dynamic_foraging_basic_analysis.data_model.foraging_session import ForagingSessionData


def compute_foraging_efficiency(
baited: bool,
Expand Down Expand Up @@ -74,38 +76,22 @@ def compute_foraging_efficiency(
reward_optimal_func = _reward_optimal_forager_no_baiting

# Formatting and sanity checks
choice_history = np.array(choice_history, dtype=float) # Convert None to np.nan, if any
reward_history = np.array(reward_history, dtype=float)
p_reward = np.array(p_reward, dtype=float)
random_number = np.array(random_number, dtype=float) if random_number is not None else None
n_trials = len(choice_history)

if not np.all(np.isin(choice_history, [0.0, 1.0]) | np.isnan(choice_history)):
raise ValueError("choice_history must contain only 0, 1, or np.nan.")

if not np.all(np.isin(reward_history, [0.0, 1.0])):
raise ValueError("reward_history must contain only 0 (False) or 1 (True).")

if choice_history.shape != reward_history.shape:
raise ValueError("choice_history and reward_history must have the same shape.")

if p_reward.shape != (2, n_trials):
raise ValueError("reward_probability must have the shape (2, n_trials)")

if random_number is not None and random_number.shape != p_reward.shape:
raise ValueError("random_number must have the same shape as reward_probability.")

if autowater_offered is not None and not autowater_offered.shape == choice_history.shape:
raise ValueError("autowater_offered must have the same shape as choice_history.")
data = ForagingSessionData(
choice_history=choice_history,
reward_history=reward_history,
p_reward=p_reward,
random_number=random_number,
autowater_offered=autowater_offered,
)

# Foraging_efficiency is calculated only on finished AND non-autowater trials
ignored = np.isnan(choice_history)
ignored = np.isnan(data.choice_history)
valid_trials = (~ignored & ~autowater_offered) if autowater_offered is not None else ~ignored

choice_history = choice_history[valid_trials]
reward_history = reward_history[valid_trials]
p_reward = p_reward[:, valid_trials]
random_number = random_number[:, valid_trials] if random_number is not None else None
choice_history = data.choice_history[valid_trials]
reward_history = data.reward_history[valid_trials]
p_reward = data.p_reward[:, valid_trials]
random_number = data.random_number[:, valid_trials] if data.random_number is not None else None

# Compute reward of the optimal forager
reward_optimal, reward_optimal_random_seed = reward_optimal_func(
Expand Down
Loading
Loading