Commit 3c07b8b: Clean up

yuanx749 committed Mar 20, 2024
1 parent 1bcf654 commit 3c07b8b
Showing 1 changed file with 21 additions and 9 deletions.
model.py: 30 changes (21 additions & 9 deletions)

@@ -132,6 +132,7 @@ def __init__(
         tau=1,
         pretrain=True,
         posthoc=False,
+        softmax=False,
     ):
         super().__init__()
         self.vocab = vocab
@@ -141,11 +142,10 @@ def __init__(
         n_class = len(cls_num_lst[1:-1])
         self.lstm = PaddedLSTM(emb_size, output_size // 2, 1, dropout)
         self.transformer_enc = TransformerModel(output_size, 2, hidden_size, 2, dropout)
-        # self.transformer = PETransformer(emb_size, 2, hidden_size, 4, dropout)
         self.attn = DotProductAttention()
         self.query = nn.Parameter(torch.randn(1, 1, output_size))
         self.clf = Classifier3(
-            hidden_size, dropout, n_class, cls_num_lst[1:-1], tau, posthoc
+            hidden_size, dropout, n_class, cls_num_lst[1:-1], tau, posthoc, softmax
         )
         self.encoder = Encoder(vocab, emb_size, pretrain)

@@ -162,7 +162,6 @@ def forward(self, input: Tensor = None, **kwargs):
         input = self.encoder(input)
         x = self.lstm(input, mask)
         x = self.transformer_enc(x, mask)
-        # x = self.transformer(input, mask)
         x_s, weights = self.attention(x, mask)
         output = self.clf(x_s, x, mask)
         output["weights"] = weights
@@ -198,7 +197,16 @@ def forward(self, x_s: Tensor, x_r: Tensor, mask):


 class Classifier3(Classifier):
-    def __init__(self, hidden_size, dropout, n_class, cls_num_lst, tau, posthoc=False):
+    def __init__(
+        self,
+        hidden_size,
+        dropout,
+        n_class,
+        cls_num_lst,
+        tau,
+        posthoc=False,
+        softmax=False,
+    ):
         super().__init__(hidden_size, dropout)
         self.mlp2 = nn.Sequential(
             nn.Linear(hidden_size, hidden_size),
@@ -207,8 +215,10 @@ def __init__(self, hidden_size, dropout, n_class, cls_num_lst, tau, posthoc=Fals
             nn.Linear(hidden_size, n_class),
         )
         prior = torch.tensor(cls_num_lst) / sum(cls_num_lst)
-        # self.adjustment = torch.log(prior)
-        self.adjustment = torch.log(prior / (1 - prior))
+        if softmax:
+            self.adjustment = torch.log(prior)
+        else:
+            self.adjustment = torch.log(prior / (1 - prior))
         self.tau = tau
         self.posthoc = posthoc

@@ -227,12 +237,14 @@ def forward(self, x_s: Tensor, x_r: Tensor, mask):


 class Loss3(nn.Module):
-    def __init__(self, gamma=0.0, alpha=1.0, ignore_index=-1):
+    def __init__(self, gamma=0.0, alpha=1.0, ignore_index=-1, softmax=False):
         super().__init__()
         self.alpha = alpha
         self.loss_fn_2 = nn.CrossEntropyLoss(ignore_index=ignore_index)
-        # self.focal_loss_3 = FocalSoftmaxLoss(gamma=gamma, ignore_index=ignore_index)
-        self.focal_loss_3 = FocalSigmoidLoss(gamma=gamma, ignore_index=ignore_index)
+        if softmax:
+            self.focal_loss_3 = FocalSoftmaxLoss(gamma=gamma, ignore_index=ignore_index)
+        else:
+            self.focal_loss_3 = FocalSigmoidLoss(gamma=gamma, ignore_index=ignore_index)

     def forward(self, output: Tdict, batch: Tdict):
         loss_p = self.loss_fn_2(output["logits_p"], batch["label_p"])
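
The new softmax flag switches the class-prior offset that Classifier3 stores as self.adjustment: the log-prior for softmax logits, or the per-class log-odds for independent sigmoid logits. A minimal sketch of the two computations, with made-up class counts (the real cls_num_lst comes from the dataset); how the offset is applied is not shown in this diff, so the last lines are only an assumption:

import torch

cls_num_lst = [500, 50, 5]  # hypothetical class counts, for illustration only
prior = torch.tensor(cls_num_lst, dtype=torch.float32) / sum(cls_num_lst)

adjustment_softmax = torch.log(prior)                # softmax=True: log-prior, as in logit-adjustment methods
adjustment_sigmoid = torch.log(prior / (1 - prior))  # softmax=False: per-class log-odds for sigmoid logits

# Assumed post-hoc use: shift the raw logits by the scaled offset at inference.
tau = 1.0
logits = torch.randn(4, len(cls_num_lst))
adjusted = logits - tau * adjustment_softmax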
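
Loss3 now picks FocalSoftmaxLoss or FocalSigmoidLoss with the same flag. Those classes live elsewhere in the repository; purely as a generic sketch (not the repository's implementation, and the masking behaviour is assumed), a sigmoid focal loss over integer labels might look like:

import torch
import torch.nn.functional as F

def focal_sigmoid_loss(logits, targets, gamma=0.0, ignore_index=-1):
    """Generic sigmoid focal loss over integer class labels (illustrative only)."""
    keep = targets != ignore_index
    logits, targets = logits[keep], targets[keep]
    onehot = F.one_hot(targets, num_classes=logits.size(-1)).float()
    p = torch.sigmoid(logits)
    pt = p * onehot + (1 - p) * (1 - onehot)  # probability assigned to the true outcome
    bce = F.binary_cross_entropy_with_logits(logits, onehot, reduction="none")
    return ((1 - pt) ** gamma * bce).mean()  # gamma=0 reduces to plain BCE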