Skip to content

Commit

Permalink
oat: faster linear version of decimate script
Browse files Browse the repository at this point in the history
  • Loading branch information
zoombya committed Oct 23, 2024
1 parent d527509 commit adbba3c
Showing 1 changed file with 15 additions and 25 deletions.
40 changes: 15 additions & 25 deletions analysis/src/oxDNA_analysis_tools/decimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Union
from oxDNA_analysis_tools.UTILS.oat_multiprocesser import oat_multiprocesser
from oxDNA_analysis_tools.UTILS.RyeReader import describe, get_confs, conf_to_str

import mmap
ComputeContext = namedtuple("ComputeContext", [
"traj_info",
"top_info"
Expand All @@ -21,7 +21,7 @@ def compute(ctx:ComputeContext, chunk_size:int, chunk_id:int):
return out


def decimate(traj:str, outfile:str, ncpus:int=1, start:int=0, stop:Union[int,None]=None, stride:int=10):
def decimate(traj: str, outfile: str, start: int, stop: int, stride: int):
"""
Reduce the number of configurations in a trajectory.
Expand All @@ -35,35 +35,26 @@ def decimate(traj:str, outfile:str, ncpus:int=1, start:int=0, stop:Union[int,Non
Writes the trajectory directly to outfile.
"""
top_info, traj_info = describe(None, traj)
# Describe the trajectory and extract indices
_, traj_info = describe(None, traj)
selected_idxs = traj_info.idxs[start:stop:stride]

#things outside the stop/start/stride bounds just don't exist anymore.
my_di = deepcopy(traj_info)
my_di.idxs = traj_info.idxs[start:stop:stride]
my_di.nconfs = len(my_di.idxs)

ctx = ComputeContext(
my_di,
top_info
)

with open(outfile, 'w+') as f:
def callback(i, r):
nonlocal f
f.write(r)

oat_multiprocesser(my_di.nconfs, ncpus, compute, callback, ctx)

# Pre-extract offsets and sizes to minimize attribute lookups
selected_data = [(idx.offset, idx.size) for idx in selected_idxs]

with open(traj, 'rb') as infile, open(outfile, 'wb') as out_file:
# Memory-map the input file for faster random access
with mmap.mmap(infile.fileno(), 0, access=mmap.ACCESS_READ) as mm:
for offset, size in selected_data:
out_file.write(mm[offset : offset + size])
log(f"Wrote decimated trajectory to {outfile}")
return

def cli_parser(prog="decimate.py"):
    """
    Build the command-line parser for the decimate script.

    Parameters:
        prog (str): Program name shown in the usage/help output.

    Returns:
        argparse.ArgumentParser: Parser exposing `traj`, `outfile`, `parallel`,
        `start`, `stop`, `stride` and `quiet`.
    """
    parser = argparse.ArgumentParser(prog = prog, description="Creates a smaller trajectory only including start/stop/stride frames from the input.")
    parser.add_argument('traj', type=str, help="The trajectory file to decimate")
    parser.add_argument('outfile', type=str, help='The name of the new trajectory file to write out')
    # Kept so existing invocations that pass -p still parse.
    # NOTE(review): confirm whether main() still forwards this value to decimate().
    parser.add_argument('-p', '--parallel', dest='parallel', default=1, type=int, help="(optional) How many cores to use")
    parser.add_argument('-s', '--start', dest='start', default=0, type=int, help='First conf to write to the output file.')
    # '-e' must be registered exactly once; a second registration makes argparse
    # raise a conflicting-option error. The default is None rather than -1 because
    # the value is used as a slice stop (idxs[start:stop:stride]): -1 would be read
    # as a negative index and silently drop the last conf, while None runs through
    # the end of the trajectory.
    parser.add_argument('-e', '--stop', dest='stop', default=None, type=int, help='Process up to this conf (exclusive). Accepts negative indexes.')
    parser.add_argument('-d', '--stride', dest='stride', default=10, type=int, help='Write out every this many confs (default=10)')
    parser.add_argument('-q', '--quiet', metavar='quiet', dest='quiet', action='store_const', const=True, default=False, help="Don't print 'INFO' messages to stderr")
    return parser
Expand All @@ -80,12 +71,11 @@ def main():
#Parse command line arguments
traj = args.traj
outfile = args.outfile
ncpus = args.parallel
start = args.start
stop = args.stop
stride = args.stride

decimate(traj=traj, outfile=outfile, ncpus=ncpus, start=start, stop=stop, stride=stride)
decimate(traj=traj, outfile=outfile, start=start, stop=stop, stride=stride)
print("--- %s seconds ---" % (time.time() - start_time))

if __name__ == '__main__':
Expand Down

0 comments on commit adbba3c

Please sign in to comment.