diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index 521ee15a47..1daf4689e4 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -41,6 +41,7 @@
 CACHE_PATH = os.path.realpath(os.path.join(os.path.realpath(__file__), '..', 'cached'))
 if not os.path.exists(CACHE_PATH):
     os.makedirs(CACHE_PATH, exist_ok=True)
+SEPARATORS = 3
 
 
 def parse_args():
@@ -151,6 +152,10 @@ def parse_args():
                              'use --dtype float16, amp will be turned on in the training phase and '
                              'fp16 will be used in evaluation.')
     args = parser.parse_args()
+
+    assert args.doc_stride <= args.max_seq_length - args.max_query_length - SEPARATORS, \
+        'Possible loss of data while chunking input features'
+
     return args
 
 
@@ -256,7 +261,7 @@ def process_sample(self, feature: SquadFeature):
         truncated_query_ids = feature.query_token_ids[:self._max_query_length]
         chunks = feature.get_chunks(
             doc_stride=self._doc_stride,
-            max_chunk_length=self._max_seq_length - len(truncated_query_ids) - 3)
+            max_chunk_length=self._max_seq_length - len(truncated_query_ids) - SEPARATORS)
         for chunk in chunks:
             data = np.array([self.cls_id] + truncated_query_ids + [self.sep_id] +
                             feature.context_token_ids[chunk.start:(chunk.start + chunk.length)] +
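
For context on the new assertion: each chunk can hold at most max_seq_length - max_query_length - SEPARATORS context tokens, because [CLS], the [SEP] after the query, and the final [SEP] occupy the other three slots, and consecutive chunk starts are doc_stride tokens apart. If the stride exceeds the per-chunk context capacity, some context tokens fall between windows and never appear in any chunk. The sketch below illustrates the arithmetic only; simple_chunks and the example numbers are hypothetical stand-ins for SquadFeature.get_chunks and are not part of this patch.

# Illustrative sketch only; `simple_chunks` approximates sliding-window chunking
# and is a hypothetical stand-in for SquadFeature.get_chunks.
def simple_chunks(num_context_tokens, doc_stride, max_chunk_length):
    """Yield (start, length) windows over the context tokens."""
    start = 0
    while True:
        length = min(max_chunk_length, num_context_tokens - start)
        yield start, length
        if start + length >= num_context_tokens:
            break
        start += doc_stride

# Assumed example values: 128-token sequences, 64-token queries, 300 context tokens.
max_seq_length, max_query_length, SEPARATORS = 128, 64, 3   # [CLS] query [SEP] context [SEP]
max_chunk_length = max_seq_length - max_query_length - SEPARATORS   # 61 context slots per chunk

covered = set()
for start, length in simple_chunks(300, doc_stride=128, max_chunk_length=max_chunk_length):
    covered.update(range(start, start + length))
print(len(covered))   # 166 of 300: with doc_stride (128) > max_chunk_length (61),
                      # tokens 61-127 and 189-255 never land in any chunk.

With the assertion in place (doc_stride <= 61 in this example), consecutive windows overlap or at least abut, so every context token appears in at least one chunk.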