Skip to content

Commit 442a2b0

Browse files
authored
piqa (#1216)
1 parent 961bd57 commit 442a2b0

File tree

6 files changed

+126
-0
lines changed

6 files changed

+126
-0
lines changed

guides/tasks/supported_tasks.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
| MRPC | mrpc ||| mrpc | GLUE |
3333
| Natural Questions | mrqa_natural_questions ||| mrqa_natural_questions | [MRQA](https://mrqa.github.io/) version of task |
3434
| NewsQA | newsqa ||| newsqa | |
35+
| PIQA | piqa ||| piqa | [PIQA](https://yonatanbisk.com/piqa/) |
3536
| QAMR | qamr ||| qamr | |
3637
| QA-SRL | qasrl ||| qasrl | |
3738
| Quoref | quoref ||| quoref | |

jiant/scripts/download_data/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"qasrl",
1414
"newsqa",
1515
"mrqa_natural_questions",
16+
"piqa",
1617
}
1718
DIRECT_DOWNLOAD_TASKS = set(
1819
list(SQUAD_TASKS) + list(DIRECT_SUPERGLUE_TASKS_TO_DATA_URLS) + list(OTHER_DOWNLOAD_TASKS)

jiant/scripts/download_data/dl_datasets/files_tasks.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ def download_task_data_and_write_config(task_name: str, task_data_path: str, tas
4848
download_mrqa_natural_questions_data_and_write_config(
4949
task_name=task_name, task_data_path=task_data_path, task_config_path=task_config_path
5050
)
51+
elif task_name == "piqa":
52+
download_piqa_data_and_write_config(
53+
task_name=task_name, task_data_path=task_data_path, task_config_path=task_config_path
54+
)
5155
else:
5256
raise KeyError(task_name)
5357

@@ -590,3 +594,42 @@ def download_mrqa_natural_questions_data_and_write_config(
590594
},
591595
path=task_config_path,
592596
)
597+
598+
599+
def download_piqa_data_and_write_config(task_name: str, task_data_path: str, task_config_path: str):
600+
os.makedirs(task_data_path, exist_ok=True)
601+
download_utils.download_file(
602+
"https://yonatanbisk.com/piqa/data/train.jsonl",
603+
os.path.join(task_data_path, "train.jsonl"),
604+
)
605+
download_utils.download_file(
606+
"https://yonatanbisk.com/piqa/data/train-labels.lst",
607+
os.path.join(task_data_path, "train-labels.lst"),
608+
)
609+
download_utils.download_file(
610+
"https://yonatanbisk.com/piqa/data/valid.jsonl",
611+
os.path.join(task_data_path, "valid.jsonl"),
612+
)
613+
download_utils.download_file(
614+
"https://yonatanbisk.com/piqa/data/valid-labels.lst",
615+
os.path.join(task_data_path, "valid-labels.lst"),
616+
)
617+
download_utils.download_file(
618+
"https://yonatanbisk.com/piqa/data/tests.jsonl",
619+
os.path.join(task_data_path, "tests.jsonl"),
620+
)
621+
622+
py_io.write_json(
623+
data={
624+
"task": task_name,
625+
"paths": {
626+
"train": os.path.join(task_data_path, "train.jsonl"),
627+
"train_labels": os.path.join(task_data_path, "train-labels.lst"),
628+
"val": os.path.join(task_data_path, "valid.jsonl"),
629+
"val_labels": os.path.join(task_data_path, "valid-labels.lst"),
630+
"test": os.path.join(task_data_path, "tests.jsonl"),
631+
},
632+
"name": task_name,
633+
},
634+
path=task_config_path,
635+
)

jiant/tasks/evaluate/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,7 @@ def get_evaluation_scheme_for_task(task) -> BaseEvaluationScheme:
931931
tasks.XnliTask,
932932
tasks.MCScriptTask,
933933
tasks.ArctTask,
934+
tasks.PiqaTask,
934935
),
935936
):
936937
return SimpleAccuracyEvaluationScheme()

jiant/tasks/lib/piqa.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from dataclasses import dataclass
2+
3+
from jiant.tasks.lib.templates.shared import labels_to_bimap
4+
from jiant.tasks.lib.templates import multiple_choice as mc_template
5+
from jiant.utils.python.io import read_json_lines, read_file_lines
6+
7+
8+
@dataclass
9+
class Example(mc_template.Example):
10+
@property
11+
def task(self):
12+
return PiqaTask
13+
14+
15+
@dataclass
16+
class TokenizedExample(mc_template.TokenizedExample):
17+
pass
18+
19+
20+
@dataclass
21+
class DataRow(mc_template.DataRow):
22+
pass
23+
24+
25+
@dataclass
26+
class Batch(mc_template.Batch):
27+
pass
28+
29+
30+
class PiqaTask(mc_template.AbstractMultipleChoiceTask):
31+
Example = Example
32+
TokenizedExample = Example
33+
DataRow = DataRow
34+
Batch = Batch
35+
36+
CHOICE_KEYS = [0, 1]
37+
CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS)
38+
NUM_CHOICES = len(CHOICE_KEYS)
39+
40+
def get_train_examples(self):
41+
return self._create_examples(
42+
lines=zip(
43+
read_json_lines(self.train_path),
44+
read_file_lines(self.path_dict["train_labels"], strip_lines=True),
45+
),
46+
set_type="train",
47+
)
48+
49+
def get_val_examples(self):
50+
return self._create_examples(
51+
lines=zip(
52+
read_json_lines(self.val_path),
53+
read_file_lines(self.path_dict["val_labels"], strip_lines=True),
54+
),
55+
set_type="val",
56+
)
57+
58+
def get_test_examples(self):
59+
return self._create_examples(
60+
lines=zip(read_json_lines(self.test_path), read_json_lines(self.test_path)),
61+
set_type="test",
62+
)
63+
64+
@classmethod
65+
def _create_examples(cls, lines, set_type):
66+
examples = []
67+
68+
for i, (ex, label_string) in enumerate(lines):
69+
examples.append(
70+
Example(
71+
guid="%s-%s" % (set_type, i),
72+
prompt=ex["goal"],
73+
choice_list=[ex["sol1"], ex["sol2"]],
74+
label=int(label_string) if set_type != "test" else cls.CHOICE_KEYS[-1],
75+
)
76+
)
77+
78+
return examples

jiant/tasks/retrieval.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from jiant.tasks.lib.xquad import XquadTask
6868
from jiant.tasks.lib.mcscript import MCScriptTask
6969
from jiant.tasks.lib.arct import ArctTask
70+
from jiant.tasks.lib.piqa import PiqaTask
7071

7172
from jiant.tasks.core import Task
7273
from jiant.utils.python.io import read_json
@@ -139,6 +140,7 @@
139140
"xquad": XquadTask,
140141
"mcscript": MCScriptTask,
141142
"arct": ArctTask,
143+
"piqa": PiqaTask,
142144
}
143145

144146

0 commit comments

Comments
 (0)