add naacl2021 observers
mihail-amazon committed May 25, 2021
1 parent 42737da commit 6836630
Showing 3 changed files with 314 additions and 90 deletions.
96 changes: 55 additions & 41 deletions README.md
This repository contains all code related to the benchmark, including scripts for downloading the relevant datasets, preprocessing them in a consistent format for benchmark submissions, evaluating any submission outputs, and running baseline models from the original benchmark description.

This repository also contains code for our NAACL paper, [Example-Driven Intent Prediction with Observers](https://arxiv.org/pdf/2010.08684.pdf).

## Datasets


Upon completion, your DialoGLUE folder should contain the following:
```
dialoglue/
├── banking
│   ├── categories.json
│   ├── test.csv
│   ├── train_10.csv
│   ├── train_5.csv
│   ├── train.csv
│   └── val.csv
├── clinc
│   ├── categories.json
│   ├── test.csv
│   ├── train_10.csv
│   ├── train_5.csv
│   ├── train.csv
│   └── val.csv
├── dstc8_sgd
│   ├── stats.csv
│   ├── test.json
│   ├── train_10.json
│   ├── train.json
│   ├── val.json
│   └── vocab.txt
├── hwu
│   ├── categories.json
│   ├── test.csv
│   ├── train_10.csv
│   ├── train_5.csv
│   ├── train.csv
│   └── val.csv
├── mlm_all.txt
├── multiwoz
│   ├── MULTIWOZ2.1
│   │   ├── dialogue_acts.json
│   │   ├── README.txt
│   │   ├── test_dials.json
│   │   ├── train_dials.json
│   │   └── val_dials.json
│   └── MULTIWOZ2.1_fewshot
│       ├── dialogue_acts.json
│       ├── README.txt
│       ├── test_dials.json
│       ├── train_dials.json
│       └── val_dials.json
├── restaurant8k
│   ├── test.json
│   ├── train_10.json
│   ├── train.json
│   ├── val.json
│   └── vocab.txt
└── top
    ├── eval.txt
    ├── test.txt
    └── ...
```

Almost all of the models can be trained/evaluated using the `run.py` script.

The commands for training/evaluating models are as follows. If you want to *only* run inference/evaluation, simply change `--num_epochs` to 0.

To train using example-driven intent prediction, add the `--example` flag to the training script. To use observer nodes, add the `--use_observers` flag.
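
For example, a BANKING77 training run with both features enabled might look like the sketch below. Apart from `--num_epochs`, `--example`, and `--use_observers`, the flag names and paths are assumptions about `run.py`'s interface and your local data layout, so verify them against the script.

```bash
# Hypothetical sketch; check flag names against run.py before use.
python run.py \
    --train_data_path dialoglue/banking/train.csv \
    --val_data_path dialoglue/banking/val.csv \
    --test_data_path dialoglue/banking/test.csv \
    --model_name_or_path bert-base-uncased \
    --task intent \
    --num_epochs 100 \
    --example \
    --use_observers
```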

### Checkpoints

The relevant *convbert* and *convbert-dg* models can be found [here](https://registry.opendata.aws/dialoglue/).
If using these scripts or the DialoGLUE benchmark, please cite the following paper:

```
@article{mehri2020dialoglue,
  title={DialoGLUE: A Natural Language Understanding Benchmark for Task-Oriented Dialogue},
  author={Mehri, Shikib and Eric, Mihail and Hakkani-Tur, Dilek},
  journal={arXiv preprint arXiv:2009.13570},
  year={2020}
}
```

If you use any code pertaining to example-driven training or observers, please cite the following paper:

```
@article{mehri2020example,
  title={Example-Driven Intent Prediction with Observers},
  author={Mehri, Shikib and Eric, Mihail and Hakkani-Tur, Dilek},
  journal={arXiv preprint arXiv:2010.08684},
  year={2020}
}
```
102 changes: 98 additions & 4 deletions bert_models.py
class IntentBertModel(torch.nn.Module):
    def __init__(self,
                 model_name_or_path: str,
                 dropout: float,
                 num_intent_labels: int,
                 use_observers: bool = False):
        super(IntentBertModel, self).__init__()
        self.bert_model = BertModel.from_pretrained(model_name_or_path)

        self.dropout = Dropout(dropout)
        self.num_intent_labels = num_intent_labels
        self.intent_classifier = nn.Linear(self.bert_model.config.hidden_size, num_intent_labels)
        self.use_observers = use_observers
        self.num_observers = 20
    def encode(self,
               input_ids: torch.tensor,
               attention_mask: torch.tensor,
               token_type_ids: torch.tensor):
        if not self.use_observers:
            # Standard BERT pooled ([CLS]) output.
            pooled_output = self.bert_model(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            token_type_ids=token_type_ids)[1]
        else:
            # Observer pooling: average the hidden states of the final
            # `num_observers` positions instead of using [CLS].
            hidden_states = self.bert_model(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            token_type_ids=token_type_ids)[0]
            pooled_output = hidden_states[:, -self.num_observers:].mean(dim=1)

        return pooled_output


    def forward(self,
                input_ids: torch.tensor,
                attention_mask: torch.tensor,
                token_type_ids: torch.tensor,
                intent_label: torch.tensor = None):
        if not self.use_observers:
            pooled_output = self.bert_model(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            token_type_ids=token_type_ids)[1]
        else:
            hidden_states = self.bert_model(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            token_type_ids=token_type_ids)[0]
            pooled_output = hidden_states[:, -self.num_observers:].mean(dim=1)

        intent_logits = self.intent_classifier(self.dropout(pooled_output))

        # Compute losses if labels provided
        if intent_label is not None:
            loss_fct = CrossEntropyLoss()
            intent_loss = loss_fct(intent_logits.view(-1, self.num_intent_labels),
                                   intent_label.type(torch.long))
        else:
            intent_loss = torch.tensor(0)

        return intent_logits, intent_loss

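For intuition, here is a minimal, self-contained sketch of the observer pooling used by both models, run on random tensors. The shapes and the stand-in hidden states are illustrative assumptions, not code from this repository.

```python
import torch

# Illustrative shapes only: batch of 2, sequence of 50 tokens, BERT-base
# hidden size, and the 20 observer positions used by the models above.
batch, seq_len, hidden, num_observers = 2, 50, 768, 20

# Stand-in for the final-layer hidden states returned by BertModel(...)[0].
hidden_states = torch.randn(batch, seq_len, hidden)

# Observer pooling: average the last `num_observers` positions instead of
# taking the [CLS] pooled output.
pooled = hidden_states[:, -num_observers:].mean(dim=1)
print(pooled.shape)  # torch.Size([2, 768])
```
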
class ExampleIntentBertModel(torch.nn.Module):
    def __init__(self,
                 model_name_or_path: str,
                 dropout: float,
                 num_intent_labels: int,
                 use_observers: bool = False):
        super(ExampleIntentBertModel, self).__init__()
        self.bert_model = BertModel(BertConfig.from_pretrained(model_name_or_path, output_attentions=True))

        self.dropout = Dropout(dropout)
        self.num_intent_labels = num_intent_labels
        self.use_observers = use_observers
        self.num_observers = 20

    def encode(self,
               input_ids: torch.tensor,
               attention_mask: torch.tensor,
               token_type_ids: torch.tensor):
        if not self.use_observers:
            pooled_output = self.bert_model(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            token_type_ids=token_type_ids)[1]
        else:
            hidden_states = self.bert_model(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            token_type_ids=token_type_ids)[0]
            pooled_output = hidden_states[:, -self.num_observers:].mean(dim=1)

        return pooled_output


    def forward(self,
                input_ids: torch.tensor,
                attention_mask: torch.tensor,
                token_type_ids: torch.tensor,
                intent_label: torch.tensor,
                example_input: torch.tensor,
                example_mask: torch.tensor,
                example_token_types: torch.tensor,
                example_intents: torch.tensor):
        example_pooled_output = self.encode(input_ids=example_input,
                                            attention_mask=example_mask,
                                            token_type_ids=example_token_types)

        pooled_output = self.encode(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)

        pooled_output = self.dropout(pooled_output)

        # Attention over the example bank: similarity of each input to every example.
        probs = torch.softmax(pooled_output.mm(example_pooled_output.t()), dim=-1)

        # Scatter each example's probability mass onto its intent label.
        intent_probs = 1e-6 + torch.zeros(probs.size(0), self.num_intent_labels).cuda().scatter_add(
            -1, example_intents.unsqueeze(0).repeat(probs.size(0), 1), probs)

        # Compute losses if labels provided
        if intent_label is not None:
            loss_fct = NLLLoss()
            intent_lp = torch.log(intent_probs)
            intent_loss = loss_fct(intent_lp.view(-1, self.num_intent_labels),
                                   intent_label.type(torch.long))
        else:
            intent_loss = torch.tensor(0)

        return intent_probs, intent_loss

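The scatter step in `ExampleIntentBertModel.forward` is dense, so here is a minimal CPU sketch of the example-driven classification step on random vectors. All sizes and names are illustrative assumptions; the model above additionally applies dropout and runs on GPU.

```python
import torch

# Illustrative sizes: 3 inputs, a bank of 5 examples, hidden size 8, 4 intents.
num_inputs, num_examples, hidden, num_intents = 3, 5, 8, 4
pooled = torch.randn(num_inputs, hidden)            # encoded inputs
example_pooled = torch.randn(num_examples, hidden)  # encoded example bank
example_intents = torch.tensor([0, 2, 2, 1, 3])     # intent label of each example

# Attention over examples: each row is a distribution over the example bank.
probs = torch.softmax(pooled.mm(example_pooled.t()), dim=-1)

# Add each example's probability to its intent label (1e-6 avoids log(0)).
intent_probs = 1e-6 + torch.zeros(num_inputs, num_intents).scatter_add(
    -1, example_intents.unsqueeze(0).repeat(num_inputs, 1), probs)

print(intent_probs.sum(dim=-1))  # each row sums to ~1 plus the smoothing term
```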

class SlotBertModel(torch.nn.Module):
    def __init__(self,
                 model_name_or_path: str,