diff --git a/codes/run.py b/codes/run.py index 457c6fdf..6c1ec67a 100644 --- a/codes/run.py +++ b/codes/run.py @@ -123,7 +123,10 @@ def read_triple(file_path, entity2id, relation2id): triples = [] with open(file_path) as fin: for line in fin: - h, r, t = line.strip().split('\t') + # The entity/relation dict have the element names at the end, when loading them + # the entire line is stripped, removing any trailing spaces. As such, we need + # to strip each element individually here as well. + h, r, t = map(str.strip, line.split('\t')) triples.append((entity2id[h], relation2id[r], entity2id[t])) return triples @@ -160,12 +163,12 @@ def log_metrics(mode, step, metrics): def main(args): if (not args.do_train) and (not args.do_valid) and (not args.do_test): - raise ValueError('one of train/val/test mode must be choosed.') + raise ValueError('one of train/val/test mode must be chosen.') if args.init_checkpoint: override_config(args) elif args.data_path is None: - raise ValueError('one of init_checkpoint/data_path must be choosed.') + raise ValueError('one of init_checkpoint/data_path must be chosen.') if args.do_train and args.save_path is None: raise ValueError('Where do you want to save your trained model?')