diff --git a/README.md b/README.md
index 874eaeb..e1c3e9d 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,8 @@ This repository contains all code related to the benchmark, including scripts fo
 relevant datasets, preprocessing them in a consistent format for benchmark submissions,
 evaluating any submission outputs, and running baseline models from the original
 benchmark description.
 
+This repository also contains code for our NAACL paper, [Example-Driven Intent Prediction with Observers](https://arxiv.org/pdf/2010.08684.pdf).
+
 ## Datasets
@@ -34,53 +36,53 @@ Upon completion, your DialoGLUE folder should contain the following:
 ```
 dialoglue/
 ├── banking
-│   ├── categories.json
-│   ├── test.csv
-│   ├── train_10.csv
-│   ├── train_5.csv
-│   ├── train.csv
-│   └── val.csv
+│ ├── categories.json
+│ ├── test.csv
+│ ├── train_10.csv
+│ ├── train_5.csv
+│ ├── train.csv
+│ └── val.csv
 ├── clinc
-│   ├── categories.json
-│   ├── test.csv
-│   ├── train_10.csv
-│   ├── train_5.csv
-│   ├── train.csv
-│   └── val.csv
+│ ├── categories.json
+│ ├── test.csv
+│ ├── train_10.csv
+│ ├── train_5.csv
+│ ├── train.csv
+│ └── val.csv
 ├── dstc8_sgd
-│   ├── stats.csv
-│   ├── test.json
-│   ├── train_10.json
-│   ├── train.json
-│   ├── val.json
-│   └── vocab.txt
+│ ├── stats.csv
+│ ├── test.json
+│ ├── train_10.json
+│ ├── train.json
+│ ├── val.json
+│ └── vocab.txt
 ├── hwu
-│   ├── categories.json
-│   ├── test.csv
-│   ├── train_10.csv
-│   ├── train_5.csv
-│   ├── train.csv
-│   └── val.csv
+│ ├── categories.json
+│ ├── test.csv
+│ ├── train_10.csv
+│ ├── train_5.csv
+│ ├── train.csv
+│ └── val.csv
 ├── mlm_all.txt
 ├── multiwoz
-│   ├── MULTIWOZ2.1
-│   │   ├── dialogue_acts.json
-│   │   ├── README.txt
-│   │   ├── test_dials.json
-│   │   ├── train_dials.json
-│   │   └── val_dials.json
-│   └── MULTIWOZ2.1_fewshot
-│   ├── dialogue_acts.json
-│   ├── README.txt
-│   ├── test_dials.json
-│   ├── train_dials.json
-│   └── val_dials.json
+│ ├── MULTIWOZ2.1
+│ │ ├── dialogue_acts.json
+│ │ ├── README.txt
+│ │ ├── test_dials.json
+│ │ ├── train_dials.json
+│ │ └── val_dials.json
+│ └── MULTIWOZ2.1_fewshot
+│ ├── dialogue_acts.json
+│ ├── README.txt
+│ ├── test_dials.json
+│ ├── train_dials.json
+│ └── val_dials.json
 ├── restaurant8k
-│   ├── test.json
-│   ├── train_10.json
-│   ├── train.json
-│   ├── val.json
-│   └── vocab.txt
+│ ├── test.json
+│ ├── train_10.json
+│ ├── train.json
+│ ├── val.json
+│ └── vocab.txt
 └── top
     ├── eval.txt
     ├── test.txt
@@ -118,6 +120,8 @@ Almost all of the models can be trained/evaluated using the `run.py` script. Mul
 
 The commands for training/evaluating models are as follows. If you want to *only* run inference/evaluation, simply change `--num_epochs` to 0.
 
+To train using example-driven intent prediction, add the `--example` flag to the training script. To use observer nodes, add the `--use_observers` flag.
+
 ### Checkpoints
 
 The relevant *convbert* and *convbert-dg* models can be found [here](https://registry.opendata.aws/dialoglue/).
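For concreteness, here is what the `--use_observers` flag changes about encoding (mirroring `encode` in `bert_models.py` below): instead of taking BERT's `[CLS]`-based pooled output, the utterance representation is the average of the final hidden states at the last 20 token positions. Because inputs are padded to a fixed `--max_seq_length`, those trailing positions act as dedicated "observer" slots. A minimal sketch, assuming `bert-base-uncased` and a recent Hugging Face `transformers`; `NUM_OBSERVERS` mirrors the constant hard-coded in `bert_models.py`:

```python
import torch
from transformers import BertModel, BertTokenizer

NUM_OBSERVERS = 20  # fixed in bert_models.py

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")

batch = tokenizer(["what is my account balance"], padding="max_length",
                  max_length=50, return_tensors="pt")
outputs = bert(**batch)

# Default pooling: the [CLS]-based pooled output.
cls_pooled = outputs[1]                      # (batch, hidden)

# Observer pooling: average the final hidden states of the
# last NUM_OBSERVERS positions instead of using [CLS].
hidden_states = outputs[0]                   # (batch, seq_len, hidden)
observer_pooled = hidden_states[:, -NUM_OBSERVERS:].mean(dim=1)
```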
@@ -224,3 +228,13 @@ If using these scripts or the DialoGLUE benchmark, please cite the following in
 }
 ```
 
+If you use any code pertaining to example-driven training or observers, please cite the following paper:
+
+```bibtex
+@article{mehri2020example,
+  title={Example-Driven Intent Prediction with Observers},
+  author={Mehri, Shikib and Eric, Mihail and Hakkani-Tur, Dilek},
+  journal={arXiv preprint arXiv:2010.08684},
+  year={2020}
+}
+```
diff --git a/bert_models.py b/bert_models.py
index 8e27ebb..b53240d 100644
--- a/bert_models.py
+++ b/bert_models.py
@@ -24,22 +24,51 @@ class IntentBertModel(torch.nn.Module):
     def __init__(self,
                  model_name_or_path: str,
                  dropout: float,
-                 num_intent_labels: int):
+                 num_intent_labels: int,
+                 use_observers: bool = False):
         super(IntentBertModel, self).__init__()
         self.bert_model = BertModel.from_pretrained(model_name_or_path)
 
         self.dropout = Dropout(dropout)
         self.num_intent_labels = num_intent_labels
         self.intent_classifier = nn.Linear(self.bert_model.config.hidden_size, num_intent_labels)
+        self.use_observers = use_observers
+        self.num_observers = 20
+
+    def encode(self,
+               input_ids: torch.tensor,
+               attention_mask: torch.tensor,
+               token_type_ids: torch.tensor):
+        if not self.use_observers:
+            pooled_output = self.bert_model(input_ids=input_ids,
+                                            attention_mask=attention_mask,
+                                            token_type_ids=token_type_ids)[1]
+        else:
+            hidden_states = self.bert_model(input_ids=input_ids,
+                                            attention_mask=attention_mask,
+                                            token_type_ids=token_type_ids)[0]
+
+            pooled_output = hidden_states[:, -self.num_observers:].mean(dim=1)
+
+        return pooled_output
 
     def forward(self,
                 input_ids: torch.tensor,
                 attention_mask: torch.tensor,
                 token_type_ids: torch.tensor,
                 intent_label: torch.tensor = None):
-        pooled_output = self.bert_model(input_ids=input_ids,
-                                        attention_mask=attention_mask,
-                                        token_type_ids=token_type_ids)[1]
+        if not self.use_observers:
+            pooled_output = self.bert_model(input_ids=input_ids,
+                                            attention_mask=attention_mask,
+                                            token_type_ids=token_type_ids)[1]
+        else:
+            hidden_states = self.bert_model(input_ids=input_ids,
+                                            attention_mask=attention_mask,
+                                            token_type_ids=token_type_ids)[0]
+
+            pooled_output = hidden_states[:, -self.num_observers:].mean(dim=1)
+
         intent_logits = self.intent_classifier(self.dropout(pooled_output))
 
         # Compute losses if labels provided
@@ -51,6 +80,71 @@ def forward(self,
         return intent_logits, intent_loss
 
 
+class ExampleIntentBertModel(torch.nn.Module):
+    def __init__(self,
+                 model_name_or_path: str,
+                 dropout: float,
+                 num_intent_labels: int,
+                 use_observers: bool = False):
+        super(ExampleIntentBertModel, self).__init__()
+        self.bert_model = BertModel(BertConfig.from_pretrained(model_name_or_path, output_attentions=True))
+
+        self.dropout = Dropout(dropout)
+        self.num_intent_labels = num_intent_labels
+        self.use_observers = use_observers
+        self.num_observers = 20
+
+    def encode(self,
+               input_ids: torch.tensor,
+               attention_mask: torch.tensor,
+               token_type_ids: torch.tensor):
+        if not self.use_observers:
+            pooled_output = self.bert_model(input_ids=input_ids,
+                                            attention_mask=attention_mask,
+                                            token_type_ids=token_type_ids)[1]
+        else:
+            hidden_states = self.bert_model(input_ids=input_ids,
+                                            attention_mask=attention_mask,
+                                            token_type_ids=token_type_ids)[0]
+
+            pooled_output = hidden_states[:, -self.num_observers:].mean(dim=1)
+
+        return pooled_output
+
+
+    def forward(self,
+                input_ids: torch.tensor,
+                attention_mask: torch.tensor,
+                token_type_ids: torch.tensor,
+                intent_label: torch.tensor,
+                example_input: torch.tensor,
+                example_mask: torch.tensor,
+                example_token_types: torch.tensor,
+                example_intents: torch.tensor):
+        example_pooled_output = self.encode(input_ids=example_input,
+                                            attention_mask=example_mask,
+                                            token_type_ids=example_token_types)
+
+        pooled_output = self.encode(input_ids=input_ids,
+                                    attention_mask=attention_mask,
+                                    token_type_ids=token_type_ids)
+
+        pooled_output = self.dropout(pooled_output)
+        probs = torch.softmax(pooled_output.mm(example_pooled_output.t()), dim=-1)
+
+        intent_probs = 1e-6 + torch.zeros(probs.size(0), self.num_intent_labels).cuda().scatter_add(
+            -1, example_intents.unsqueeze(0).repeat(probs.size(0), 1), probs)
+
+        # Compute losses if labels provided
+        if intent_label is not None:
+            loss_fct = NLLLoss()
+            intent_lp = torch.log(intent_probs)
+            intent_loss = loss_fct(intent_lp.view(-1, self.num_intent_labels), intent_label.type(torch.long))
+        else:
+            intent_loss = torch.tensor(0)
+
+        return intent_probs, intent_loss
+
+
 class SlotBertModel(torch.nn.Module):
     def __init__(self,
                  model_name_or_path: str,
diff --git a/run.py b/run.py
index 5922345..2906b29 100644
--- a/run.py
+++ b/run.py
@@ -23,7 +23,7 @@
 from constants import SPECIAL_TOKENS
 from data_readers import IntentDataset, SlotDataset, TOPDataset
-from bert_models import BertPretrain, IntentBertModel, JointSlotIntentBertModel, SlotBertModel
+from bert_models import BertPretrain, ExampleIntentBertModel, IntentBertModel, JointSlotIntentBertModel, SlotBertModel
 
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -46,6 +46,8 @@ def read_args():
     parser.add_argument("--dump_outputs", action="store_true")
     parser.add_argument("--mlm_pre", action="store_true")
     parser.add_argument("--mlm_during", action="store_true")
+    parser.add_argument("--example", action="store_true")
+    parser.add_argument("--use_observers", action="store_true")
     parser.add_argument("--repeat", type=int, default=1)
     parser.add_argument("--grad_accum", type=int, default=1)
     parser.add_argument("--train_batch_size", type=int, default=16)
@@ -63,14 +65,86 @@
     parser.add_argument("--seed", type=int, default=42)
     return parser.parse_args()
 
+
+def retrieve_examples(dataset, labels, inds, task, num=None, cache=defaultdict(list)):
+    if num is None and labels is not None:
+        num = len(labels) * 2
+
+    assert task == "intent", "Example-driven may only be used with intent prediction"
+
+    if len(cache) == 0:
+        # Populate cache
+        for i, example in enumerate(dataset):
+            cache[example['intent_label']].append(i)
+
+        print("Populated example cache.")
+
+    # One example for each label
+    example_inds = []
+    for l in set(labels.tolist()):
+        if l == -1:
+            continue
+
+        ind = random.choice(cache[l])
+        retries = 0
+        while ind in inds.tolist() or type(ind) is not int:
+            ind = random.choice(cache[l])
+            retries += 1
+            if retries > len(dataset):
+                break
+
+        example_inds.append(ind)
+
+    # Sample randomly until we hit batch size
+    while len(example_inds) < min(len(dataset), num):
+        ind = random.randint(0, len(dataset) - 1)
+        if ind not in example_inds and ind not in inds.tolist():
+            example_inds.append(ind)
+
+    # Create examples
+    example_data = [dataset[i] for i in example_inds]
+    examples = {}
+    for key in ['input_ids', 'attention_mask', 'token_type_ids']:
+        examples[key] = torch.stack([torch.LongTensor(e[key]) for e in example_data], dim=0).cuda()
+
+    examples['intent_label'] = torch.LongTensor([e['intent_label'] for e in example_data]).cuda()
+
+    return examples
+
+
 def evaluate(model: torch.nn.Module,
              eval_dataloader: DataLoader,
+             ex_dataloader: DataLoader,
              tokenizer: Any,
              task: str = "intent",
+             example: bool = False,
              device: int = 0,
             args: Any = None) -> Tuple[float, float, float]:
     model.eval()
+
+    bert_output = []
+    labels = []
+    if example:
+        assert task == "intent", "Example-Driven may only be used for intent prediction"
+
+        with torch.no_grad():
+            for batch in tqdm(ex_dataloader, desc="Building train memory."):
+                # Move to GPU
+                if torch.cuda.is_available():
+                    for key, val in batch.items():
+                        if type(batch[key]) is list:
+                            continue
+
+                        batch[key] = batch[key].to(device)
+
+                pooled_output = model.encode(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])
+
+                bert_output.append(pooled_output.cpu())
+                labels += batch["intent_label"].tolist()
+
+        mem = torch.cat(bert_output, dim=0).cuda()
+        print("Memory size:", mem.size())
+
     pred = []
     true = []
     for batch in tqdm(eval_dataloader, desc="Evaluating"):
@@ -84,17 +158,33 @@
                 batch[key] = batch[key].to(device)
 
             if task == "intent":
-                # Forward prop
-                intent_logits, intent_loss = model(input_ids=batch["input_ids"],
-                                                   attention_mask=batch["attention_mask"],
-                                                   token_type_ids=batch["token_type_ids"],
-                                                   intent_label=batch["intent_label"])
+                if not example:
+                    # Forward prop
+                    intent_logits, intent_loss = model(input_ids=batch["input_ids"],
+                                                       attention_mask=batch["attention_mask"],
+                                                       token_type_ids=batch["token_type_ids"],
+                                                       intent_label=batch["intent_label"])
+
+                    # Argmax to get predictions
+                    intent_preds = torch.argmax(intent_logits, dim=1).cpu().tolist()
+
+                    pred += intent_preds
+                    true += batch["intent_label"].cpu().tolist()
+                else:
+                    # Encode input
+                    pooled_output = model.encode(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])
 
-                # Argmax to get predictions
-                intent_preds = torch.argmax(intent_logits, dim=1).cpu().tolist()
+                    # Probability distribution over examples
+                    probs = torch.softmax(pooled_output.mm(mem.t())[0], dim=-1)
 
-                pred += intent_preds
-                true += batch["intent_label"].cpu().tolist()
+                    # Copy mechanism over training set
+                    intent_probs = torch.zeros(len(ex_dataloader.dataset.intent_idx_to_label)).cuda().scatter_add(
+                        0, torch.LongTensor(labels).cuda(), probs)
+
+                    pred.append(intent_probs.argmax(dim=-1).item())
+                    true += batch["intent_label"].cpu().tolist()
             elif task == "slot":
                 # Forward prop
                 slot_logits, slot_loss = model(input_ids=batch["input_ids"],
@@ -106,11 +196,11 @@
                 slot_preds = torch.argmax(slot_logits, dim=2).detach().cpu().numpy()
 
                 # Generate words, true slots and pred slots
-                words = [ tokenizer.decode([e]) for e in batch["input_ids"][0].tolist() ]
+                words = [tokenizer.decode([e]) for e in batch["input_ids"][0].tolist()]
                 actual_gold_slots = batch["slot_labels"].cpu().numpy().squeeze().tolist()
-                true_slots = [ eval_dataloader.dataset.slot_idx_to_label[s] for s in actual_gold_slots ]
+                true_slots = [eval_dataloader.dataset.slot_idx_to_label[s] for s in actual_gold_slots]
                 actual_predicted_slots = slot_preds.squeeze().tolist()
-                pred_slots = [ eval_dataloader.dataset.slot_idx_to_label[s] for s in actual_predicted_slots ]
+                pred_slots = [eval_dataloader.dataset.slot_idx_to_label[s] for s in actual_predicted_slots]
 
                 # Find the last turn and only include that. Irrelevant for restaurant8k/dstc8-sgd.
                 if '>' in words:
@@ -121,8 +211,8 @@
                 # Filter out words that are padding
                 filt_words = [w for w in words if w not in ['', 'user']]
-                true_slots = [s for w,s in zip(words, true_slots) if w not in ['', 'user']]
-                pred_slots = [s for w,s in zip(words, pred_slots) if w not in ['', 'user']]
+                true_slots = [s for w, s in zip(words, true_slots) if w not in ['', 'user']]
+                pred_slots = [s for w, s in zip(words, pred_slots) if w not in ['', 'user']]
 
                 # Convert to slot labels
                 pred.append(pred_slots)
@@ -146,8 +236,8 @@
                 # Only unmasked
                 pad_ind = batch["attention_mask"].tolist()[0].index(0)
-                actual_gold_slots = actual_gold_slots[1:pad_ind-1]
-                actual_predicted_slots = actual_predicted_slots[1:pad_ind-1]
+                actual_gold_slots = actual_gold_slots[1:pad_ind - 1]
+                actual_predicted_slots = actual_predicted_slots[1:pad_ind - 1]
 
                 # Add to lists
                 pred.append((intent_preds if type(intent_preds) is int else intent_preds[0], actual_predicted_slots))
@@ -160,7 +250,7 @@ def _extract(slot_labels):
     slots = []
     cur_key = None
     start_ind = -1
-    for i,s in enumerate(slot_labels):
+    for i, s in enumerate(slot_labels):
         if s == "O" or s == "[PAD]":
             # Add on-going slot if there is one
             if cur_key is not None:
@@ -178,7 +268,7 @@
             cur_key = slot_key
             start_ind = i
         elif token_type == "I":
-            # If the slot key doesn't match the currently active, this is invalid. 
+            # If the slot key doesn't match the currently active, this is invalid.
             # Treat this as an O.
             if slot_key != cur_key:
                 if cur_key is not None:
@@ -199,7 +289,7 @@
             pred_labels = [eval_dataloader.dataset.intent_idx_to_label.get(p) for p in pred]
             json.dump(pred_labels, open(args.output_dir + "outputs.json", "w+"))
 
-        return sum(p == t for p,t in zip(pred, true))/len(pred)
+        return sum(p == t for p, t in zip(pred, true)) / len(pred)
     elif task == "slot":
         pred_slots = [_extract(e) for e in pred]
         true_slots = [_extract(e) for e in true]
@@ -212,10 +302,10 @@
         for slot_type in slot_types:
             predictions_for_slot = [
-                [p for p in prediction if slot_type in p] for prediction in pred_slots
+                [p for p in prediction if slot_type in p] for prediction in pred_slots
             ]
             labels_for_slot = [
-                [l for l in label if slot_type in l] for label in true_slots
+                [l for l in label if slot_type in l] for label in true_slots
             ]
 
             proposal_made = [len(p) > 0 for p in predictions_for_slot]
@@ -239,17 +329,19 @@ def _extract(slot_labels):
         return np.mean(slot_type_f1_scores)
     elif task == "top":
         if args.dump_outputs:
-            pred_labels = [(eval_dataloader.dataset.intent_idx_to_label[intent], [eval_dataloader.dataset.slot_idx_to_label[e] for e in slots ]) for intent,slots in pred]
+            pred_labels = [(eval_dataloader.dataset.intent_idx_to_label[intent],
+                            [eval_dataloader.dataset.slot_idx_to_label[e] for e in slots]) for intent, slots in pred]
             json.dump(pred_labels, open(args.output_dir + "outputs.json", "w+"))
 
-        return sum(p == t for p,t in zip(pred, true))/len(pred)
+        return sum(p == t for p, t in zip(pred, true)) / len(pred)
+
 
 def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
     """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
     probability_matrix = torch.full(labels.shape, mlm_probability)
-    #special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
+    # special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
     probability_matrix.masked_fill_(torch.tensor(labels == 0, dtype=torch.bool), value=0.0)
     masked_indices = torch.bernoulli(probability_matrix).bool()
@@ -267,6 +359,7 @@
     # The rest of the time (10% of the time) we keep the masked input tokens unchanged
     return inputs, labels
 
+
 def train(args, rep):
     # Set random seed
     random.seed(args.seed)
@@ -277,10 +370,13 @@
     if args.output_dir == "":
         cwd = os.getcwd()
         base = args.model_name_or_path.split("/")[-1]
+        model_type = "_example" if args.example else "_linear"
         data_path = '_' + '_'.join(args.train_data_path.split("/")[-2:]).replace(".csv", "")
+        mlm_on = "_mlmtrain" if args.mlm_data_path == "" or args.mlm_data_path == args.train_data_path else "_mlmfull"
         mlm_pre = "_mlmpre" if args.mlm_pre else ""
         mlm_dur = "_mlmdur" if args.mlm_during else ""
-        name = base + data_path + mlm_pre + mlm_dur + "_v{}".format(rep)
+        observer = "_observer" if args.use_observers else ""
+        name = base + model_type + data_path + mlm_on + mlm_pre + mlm_dur + observer + "_v{}".format(rep)
         args.output_dir = os.path.join(cwd, "checkpoints", name)
 
         if not os.path.exists(args.output_dir):
@@ -303,18 +399,18 @@
     token_vocab_name = os.path.basename(args.token_vocab_path).replace(".txt", "")
     tokenizer = BertWordPieceTokenizer(args.token_vocab_path,
                                        lowercase=args.do_lowercase)
-    tokenizer.enable_padding(max_length=args.max_seq_length)
+    tokenizer.enable_padding(length=args.max_seq_length)
 
     if args.num_epochs > 0:
-        tokenizer.save(args.output_dir)
+        tokenizer.save(args.output_dir)
 
-    # Data readers
+    # Data readers
     if args.task == "intent":
         dataset_initializer = IntentDataset
     elif args.task == "slot":
         dataset_initializer = SlotDataset
     elif args.task == "top":
-        dataset_initializer = TOPDataset
+        dataset_initializer = TOPDataset
     else:
         raise ValueError("Not a valid task type: {}".format(args.task))
@@ -362,9 +458,16 @@
     # Load model
     if args.task == "intent":
-        model = IntentBertModel(args.model_name_or_path,
-                                dropout=args.dropout,
-                                num_intent_labels=len(train_dataset.intent_label_to_idx))
+        if args.example:
+            model = ExampleIntentBertModel(args.model_name_or_path,
+                                           dropout=args.dropout,
+                                           num_intent_labels=len(train_dataset.intent_label_to_idx),
+                                           use_observers=args.use_observers)
+        else:
+            model = IntentBertModel(args.model_name_or_path,
+                                    dropout=args.dropout,
+                                    num_intent_labels=len(train_dataset.intent_label_to_idx),
+                                    use_observers=args.use_observers)
     elif args.task == "slot":
         model = SlotBertModel(args.model_name_or_path,
                               dropout=args.dropout,
@@ -389,7 +492,7 @@
     # MLM Pre-train
     if args.mlm_pre and args.num_epochs > 0:
-        # Maintain most recent score per label. 
+        # Maintain most recent score per label.
         for epoch in trange(3, desc="Pre-train Epochs"):
             pre_model.train()
             epoch_loss = 0
@@ -432,7 +535,7 @@
             epoch_loss = 0
             num_batches = 0
-            for batch in tqdm(train_dataloader): 
+            for batch in tqdm(train_dataloader):
                 num_batches += 1
                 global_step += 1
@@ -446,14 +549,24 @@
                 # Train model
                 if args.task == "intent":
-                    _, intent_loss = model(input_ids=batch["input_ids"],
-                                           attention_mask=batch["attention_mask"],
-                                           token_type_ids=batch["token_type_ids"],
-                                           intent_label=batch["intent_label"])
-
+                    if args.example:
+                        examples = retrieve_examples(train_dataset, batch["intent_label"], batch["ind"], task="intent")
+
+                        _, intent_loss = model(input_ids=batch["input_ids"],
+                                               attention_mask=batch["attention_mask"],
+                                               token_type_ids=batch["token_type_ids"],
+                                               intent_label=batch["intent_label"],
+                                               example_input=examples["input_ids"],
+                                               example_mask=examples["attention_mask"],
+                                               example_token_types=examples["token_type_ids"],
+                                               example_intents=examples["intent_label"])
+                    else:
+                        _, intent_loss = model(input_ids=batch["input_ids"],
+                                               attention_mask=batch["attention_mask"],
+                                               token_type_ids=batch["token_type_ids"],
+                                               intent_label=batch["intent_label"])
                     if args.grad_accum > 1:
                         intent_loss = intent_loss / args.grad_accum
-
                     intent_loss.backward()
                     epoch_loss += intent_loss.item()
                 elif args.task == "slot":
@@ -488,7 +601,8 @@
             LOGGER.info("Epoch loss: {}".format(epoch_loss / num_batches))
 
             # Evaluate and save checkpoint
-            score = evaluate(model, val_dataloader, tokenizer, task=args.task, device=args.device, args=args)
+            score = evaluate(model, val_dataloader, train_dataloader, tokenizer, task=args.task, example=args.example,
+                             device=args.device, args=args)
             metrics_to_log["eval_score"] = score
             LOGGER.info("Task: {}, score: {}---".format(args.task, score))
@@ -513,7 +627,7 @@
                 break
 
         # Run MLM during training
-        if args.mlm_during: 
+        if args.mlm_during:
             pre_model.train()
             epoch_loss = 0
             num_batches = 0
@@ -546,12 +660,14 @@
     # Evaluate on test set
     LOGGER.info("Loading up best model for test evaluation...")
     model.load_state_dict(torch.load(os.path.join(args.output_dir, "model.pt")))
-    score = evaluate(model, test_dataloader, tokenizer, task=args.task, device=args.device, args=args)
+    score = evaluate(model, test_dataloader, train_dataloader, tokenizer, task=args.task, example=args.example,
+                     device=args.device, args=args)
     print("Best result for {}: Score: {}".format(args.task, score))
     tb_writer.add_scalar("final_test_score", score, global_step)
     tb_writer.close()
 
     return score
 
+
 if __name__ == "__main__":
     args = read_args()
     print(args)
@@ -562,7 +678,7 @@ def train(args, rep):
         if args.num_epochs > 0:
             args.output_dir = ""
 
-        args.seed = seeds[i] if i < len(seeds) else random.randint(1,999)
+        args.seed = seeds[i] if i < len(seeds) else random.randint(1, 999)
         scores.append(train(args, i))
         print("Average score so far:", np.mean(scores))
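The core of the `ExampleIntentBertModel` changes above is a similarity-based copy mechanism: the query utterance and a set of labeled example utterances are encoded with the same encoder, a softmax over dot-product similarities yields a distribution over examples, and `scatter_add` pools the probability mass of examples sharing an intent into a distribution over intents, trained with an NLL loss on its log. A self-contained sketch with toy tensors standing in for encoder outputs (names here are illustrative, not the repo's API):

```python
import torch
from torch.nn import NLLLoss

batch, n_examples, hidden, num_intents = 4, 10, 8, 5

query_enc = torch.randn(batch, hidden)         # stands in for encode(query utterances)
example_enc = torch.randn(n_examples, hidden)  # stands in for encode(example utterances)
example_intents = torch.randint(num_intents, (n_examples,))
intent_label = torch.randint(num_intents, (batch,))

# Distribution over examples via dot-product similarity.
probs = torch.softmax(query_enc.mm(example_enc.t()), dim=-1)   # (batch, n_examples)

# Copy mechanism: sum each example's probability into its intent's bucket.
intent_probs = 1e-6 + torch.zeros(batch, num_intents).scatter_add(
    -1, example_intents.unsqueeze(0).repeat(batch, 1), probs)

# NLL loss on the log of the aggregated distribution.
loss = NLLLoss()(torch.log(intent_probs), intent_label)
print(intent_probs.sum(dim=-1))  # each row sums to ~1 (plus the 1e-6 smoothing)
```

`retrieve_examples` in `run.py` supports this objective: each training batch gets at least one example per gold label present in the batch (so the correct intent is always reachable) plus random fill up to twice the batch size, excluding the batch utterances themselves.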
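At inference time (`evaluate` in `run.py`), the same mechanism runs against a fixed memory: the entire training set is encoded once and stacked into a matrix, and each test utterance is classified by softmaxing its similarities to that memory and scatter-adding over the memorized labels. A rough sketch continuing the toy setup above (again with illustrative names):

```python
# Build the memory once from the "training set" encodings.
mem = example_enc            # (n_examples, hidden); run.py stacks these via torch.cat
mem_labels = example_intents # intent id of each memory entry

def predict(query: torch.Tensor) -> int:
    """Classify a single encoded utterance against the memory."""
    probs = torch.softmax(query.unsqueeze(0).mm(mem.t())[0], dim=-1)  # (n_examples,)
    intent_probs = torch.zeros(num_intents).scatter_add(0, mem_labels, probs)
    return intent_probs.argmax().item()

print(predict(query_enc[0]))
```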
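Finally, since `mask_tokens` appears in the diff context: the corruption it implements is the standard BERT MLM recipe, selecting roughly 15% of positions, then replacing 80% of those with `[MASK]`, 10% with a random token, and leaving 10% unchanged. A sketch of that 80/10/10 split, assuming padding id 0 and the usual Hugging Face ignore-index of -100; `mask_tokens_sketch` is a hypothetical stand-in, not the repo's exact function:

```python
import torch

def mask_tokens_sketch(inputs: torch.Tensor, mask_token_id: int, vocab_size: int,
                       mlm_probability: float = 0.15):
    """80% [MASK], 10% random token, 10% unchanged, on ~15% of positions."""
    labels = inputs.clone()
    probability_matrix = torch.full(labels.shape, mlm_probability)
    probability_matrix.masked_fill_(labels == 0, value=0.0)  # never mask padding (id 0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # loss is only computed on masked positions

    # 80% of masked positions -> [MASK]
    replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[replaced] = mask_token_id

    # Half of the remainder (10% overall) -> random token
    randomized = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~replaced
    inputs[randomized] = torch.randint(vocab_size, labels.shape, dtype=torch.long)[randomized]

    # Remaining 10% stay unchanged
    return inputs, labels
```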