Skip to content

Commit

Permalink
Fix PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
koparasy committed Sep 18, 2024
1 parent 9990a49 commit 000e7ec
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 8 deletions.
8 changes: 4 additions & 4 deletions examples/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@


class ExampleDataset(Dataset):
def __init__(self, X, y):
self.X = X
def __init__(self, x, y):
self.x = x
self.y = y

def __len__(self):
return len(self.X)
return len(self.x)

def __getitem__(self, idx):
features = torch.tensor(self.X[idx]).to(torch.float32)
features = torch.tensor(self.x[idx]).to(torch.float32)
target = torch.tensor(self.y[idx]).to(torch.float32)
return features, target

Expand Down
2 changes: 1 addition & 1 deletion src/AMSWorkflow/ams/ams_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class AMSFluxExecutorFuture(Future):

# NOTE: This is the primary difference of the original FluxExecutorFuture.
# FluxExecutorFuture uses frozensets without adding "memo" and requires all registered
# callbacks to be part of the EVENTS. Thus we cannot inherit directly from it and
# callbacks to be part of the EVENTS.
EVENTS = frozenset(("memo", *list(MAIN_EVENTS)))

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion src/AMSWorkflow/ams/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def from_dict(cls, config):
raise RuntimeError("Config file expects a 'db' entry\n")

for key in {"path", "type"}:
assert key in db, f"Config does not have {k} entry"
assert key in db, f"Config does not have {key} entry"
return cls(config["name"], db["path"], db["type"], db["store"] if "store" in db else None)

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions src/AMSWorkflow/ams/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import threading
import signal
import time
import os
from queue import Queue
import warnings
from enum import Enum
Expand Down Expand Up @@ -209,6 +210,7 @@ def train_job_spec(self, value):
self._train_job_spec = value
elif isinstance(value, dict):
self._train_job_spec = AMSJob.from_dict(value)
self._train_job_spec.environ = os.environ
else:
raise ValueError("The train job spec expects either a dictionary or a AMSJob type")
return
Expand All @@ -225,6 +227,7 @@ def sub_select_job_spec(self, value):
self._sub_select_job_spec = value
elif isinstance(value, dict):
self._sub_select_job_spec = AMSJob.from_dict(value)
self._sub_select_job_spec.environ = os.environ
else:
raise ValueError("The train job spec expects either a dictionary or a AMSJob type")
return
Expand Down
9 changes: 7 additions & 2 deletions src/AMSWorkflow/ams/wf_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def create_domain_list(domains: List[Dict]) -> JobList:
if not all(key in data["db"] for key in {"kosh-path", "name", "store-name"}):
raise KeyError("Workflow description files misses entries in 'db'")

store = AMSDataStore(data["db"]["kosh-path"], data["db"]["store-name"], data["db"]["name"])
store = AMSDataStore(data["db"]["kosh-path"], data["db"]["store-name"], data["db"]["name"]).open()

if "domain-jobs" not in data:
raise KeyError("Workflow description files misses 'domain-jobs' entry")
Expand All @@ -235,6 +235,9 @@ def create_domain_list(domains: List[Dict]) -> JobList:
raise RuntimeError("There are no jobs described in workflow description file")

domain_jobs = create_domain_list(data["domain-jobs"])
ams_rmq_config = AMSRMQConfiguration.from_json(rmq_config)
for job in domain_jobs:
job.precede_deploy(store, ams_rmq_config)

if "stage-job" not in data:
raise RuntimeError("There is no description for a stage-job")
Expand All @@ -254,6 +257,7 @@ def create_domain_list(domains: List[Dict]) -> JobList:
stage_job.environ = os.environ
stage_job.stdout = "stager_test.out"
stage_job.stderr = "stager_test.err"
print("Stager command is:", " ".join(stage_job.generate_cli_command()))
stage_jobs.append(stage_job)

sub_select_jobs = JobList()
Expand All @@ -279,8 +283,9 @@ def create_domain_list(domains: List[Dict]) -> JobList:
for domain in wf_domain_names:
assert domain in train_domains, f"Domain {domain} misses a train description"
assert domain in sub_select_domains, f"Domain {domain} misses a subselection description"

store.close()
store = AMSDataStore(data["db"]["kosh-path"], data["db"]["store-name"], data["db"]["name"])

return cls(
rmq_config,
data["db"]["kosh-path"],
Expand Down

0 comments on commit 000e7ec

Please sign in to comment.