-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
64 changed files
with
5,839 additions
and
3,268 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Neither target produces a file of its own name; declaring them phony keeps
# `make` from skipping a recipe when a file or directory called "format" or
# "lint" happens to exist in the working directory.
.PHONY: format lint

# Reformat the code base and apply ruff's autofixes.
format:
	poetry run ruff format --config pyproject.toml
	poetry run ruff check --fix --config pyproject.toml

# Check formatting, lint rules and static types.
lint:
	poetry run ruff format --check --config pyproject.toml
	poetry run ruff check --config pyproject.toml
	poetry run mypy --config-file pyproject.toml .
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import asyncio | ||
import json | ||
import logging | ||
import os | ||
import random | ||
import sys | ||
from argparse import ArgumentParser | ||
from typing import Any, Dict | ||
|
||
import hydra | ||
import jsonlines | ||
import pandas as pd # type: ignore[import-untyped] | ||
import wandb | ||
from dotenv import load_dotenv | ||
from hydra import compose, initialize | ||
from omegaconf import OmegaConf | ||
from tqdm import tqdm # type: ignore[import-untyped] | ||
|
||
from configs import BaselineConfig | ||
from src import CMGBackbone, CMGBaseline, CMGMetrics | ||
|
||
load_dotenv() | ||
|
||
# Configure the root logger: records at INFO level and above are written to
# stdout as "[time][logger name][level] - message".
formatter = logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] - %(message)s")
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)

root = logging.getLogger()
root.setLevel(logging.INFO)
root.addHandler(handler)
|
||
|
||
def init_baseline(cfg: BaselineConfig) -> CMGBaseline:
    """Build the baseline described by the config.

    Args:
        cfg: Run configuration; ``cfg.backbone`` and ``cfg.preprocessor`` are
            instantiated through Hydra.

    Returns:
        A ``CMGBaseline`` wrapping the instantiated backbone and preprocessor.
    """
    # the backbone goes first: the preprocessor is told which provider it serves
    model_backbone: CMGBackbone = hydra.utils.instantiate(cfg.backbone)

    preprocessor_overrides = {
        "model_name": cfg.backbone.model_name,
        "model_provider": model_backbone.name,
    }
    data_preprocessor = hydra.utils.instantiate(cfg.preprocessor, **preprocessor_overrides)

    return CMGBaseline(backbone=model_backbone, preprocessor=data_preprocessor)
|
||
|
||
async def get_predictions(
    baseline: CMGBaseline, cfg: BaselineConfig, predictions_path: str = "predictions.jsonl"
) -> str:
    """Generate a commit message for every input example and store the results.

    All examples are processed concurrently. Each finished prediction is
    appended to ``predictions_path`` as one JSON line, so the order of the
    output lines may differ from the order of the input examples.

    Args:
        baseline: Baseline whose ``agenerate_msg`` produces the output for a commit.
        cfg: Run configuration; ``cfg.data_src`` is instantiated as the input iterator.
        predictions_path: JSON Lines file to write; emptied before generation starts.

    Returns:
        ``predictions_path``, unchanged.
    """
    # input iterator: either a local file (its path is made absolute first)
    # or a HuggingFace dataset
    if hasattr(cfg.data_src, "path"):
        cfg.data_src.path = hydra.utils.to_absolute_path(cfg.data_src.path)  # type: ignore[attr-defined]
    examples = hydra.utils.instantiate(cfg.data_src)

    async def _predict_and_store(commit: Dict[str, Any]) -> None:
        generated = await baseline.agenerate_msg(commit=commit)  # type: ignore[arg-type]
        record = {"reference": commit["message"], "hash": commit["hash"], "repo": commit["repo"]}
        record.update(generated)

        with jsonlines.open(predictions_path, "a") as sink:
            sink.write(record)

    # start from an empty file, since every prediction is appended to it
    with open(predictions_path, "w"):
        pass

    await asyncio.gather(*map(_predict_and_store, examples))
    return predictions_path
|
||
|
||
def compute_metrics(predictions_path: str) -> Dict[str, float]:
    """Score stored predictions against their references.

    Args:
        predictions_path: JSON Lines file whose records carry ``"prediction"``
            and ``"reference"`` keys.

    Returns:
        The result of ``CMGMetrics.compute()`` after all records were fed in;
        it is also printed to stdout.
    """
    scorer = CMGMetrics()
    with jsonlines.open(predictions_path, "r") as records:
        # examples are fed one at a time, each as a single-element batch
        for record in tqdm(records, desc="Computing metrics"):
            scorer.update(predictions=[record["prediction"]], references=[record["reference"]])

    results = scorer.compute()
    print("=== METRICS ===", results, sep="\n")
    return results
|
||
|
||
async def main(config_name: str) -> None:
    """Run the asynchronous commit message generation pipeline for one config.

    Composes the Hydra config ``config_name`` from ``configs/async``, builds the
    baseline, writes ``predictions.jsonl`` and ``metrics.json`` under
    ``results/<run name>/`` and, when ``cfg.logger.use_wandb`` is set, also
    logs the predictions and metrics to Weights & Biases.

    Args:
        config_name: Name of a config under ``configs/async``. The run name is
            this value with a trailing ``.yaml`` removed, if there is one.
    """
    initialize(version_base="1.1", config_path="configs/async")
    cfg_dict = compose(config_name=config_name)
    cfg = BaselineConfig(**cfg_dict)  # type: ignore

    # strip the extension only when it is really there: unconditionally slicing
    # off len(".yaml") characters mangles a config name passed without extension
    extension = ".yaml"
    run_name = config_name[: -len(extension)] if config_name.endswith(extension) else config_name
    results_dir = f"results/{run_name}"
    os.makedirs(results_dir, exist_ok=True)

    # a backbone that has a seed option but no value gets a random one,
    # which is logged so the run can be traced back to it
    if hasattr(cfg.backbone, "seed") and cfg.backbone.seed is None:
        cfg.backbone.seed = random.randint(1, 2**32)
        logging.warning(f"Using random seed {cfg.backbone.seed}.")

    # init W&B (optional)
    if cfg.logger.use_wandb:
        wandb.init(
            project=cfg.logger.project,
            name=cfg.logger.name,
            config=OmegaConf.to_container(cfg, resolve=True),  # type: ignore[arg-type]
            job_type="eval",
        )

    # init baseline
    baseline = init_baseline(cfg)

    # obtain predictions
    predictions_path = await get_predictions(
        cfg=cfg, baseline=baseline, predictions_path=f"{results_dir}/predictions.jsonl"
    )

    # log predictions to W&B (optional)
    if cfg.logger.use_wandb:
        artifact = wandb.Artifact(
            f"{cfg.backbone.model_name.replace('/', '__')}_{cfg.preprocessor._target_.split('.')[-1]}_{cfg.logger.name + '_' if cfg.logger.name else ''}predictions",
            type="dataset",
        )
        if cfg.logger.local_artifact:
            # only reference the local file instead of adding its contents
            artifact.add_reference(f"file:///{os.path.abspath(predictions_path)}")
        else:
            test_table = wandb.Table(dataframe=pd.read_json(predictions_path, orient="records", lines=True))
            artifact.add(test_table, "predictions")
        wandb.log_artifact(artifact)

    # compute metrics
    computed_metrics = compute_metrics(predictions_path)
    with open(f"{results_dir}/metrics.json", "w") as f:
        json.dump(computed_metrics, f)

    # log metrics to W&B (optional)
    if cfg.logger.use_wandb:
        wandb.log(computed_metrics)
|
||
|
||
if __name__ == "__main__":
    # command-line entry point: the only option is the (required) config name
    cli = ArgumentParser(
        description="Launch a commit message generation model for Long Code Arena dataset asynchronously."
    )
    cli.add_argument(
        "--config-name",
        required=True,
        type=str,
        help="Which config under `configs/async` directory to use.",
    )
    cli_args = cli.parse_args()

    asyncio.run(main(config_name=cli_args.config_name))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
backbone: | ||
_target_: src.backbones.TogetherBackbone | ||
prompt: | ||
_target_: src.prompts.DetailedCMGPrompt | ||
model_name: deepseek-ai/DeepSeek-R1 | ||
api_key: null | ||
parameters: | ||
temperature: 0.8 | ||
preprocessor: | ||
_target_: src.preprocessors.SimpleCMGPreprocessor | ||
include_path: true | ||
logger: | ||
use_wandb: false | ||
name: null | ||
project: null | ||
local_artifact: null | ||
data_src: | ||
_target_: src.data_sources.HFDataSource | ||
cache_dir: null | ||
hub_name: JetBrains-Research/lca-commit-message-generation | ||
configs: | ||
- default | ||
split: test |
26 changes: 26 additions & 0 deletions
26
commit_message_generation/configs/async/DeepSeek-V3-16k-files.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
backbone: | ||
_target_: src.backbones.TogetherBackbone | ||
prompt: | ||
_target_: src.prompts.DetailedCMGPromptForFullFiles | ||
model_name: deepseek-ai/DeepSeek-V3 | ||
api_key: null | ||
parameters: | ||
temperature: 0.8 | ||
preprocessor: | ||
_target_: src.preprocessors.LoadFromDatasetPreprocessor | ||
include_path: true | ||
hf_repo_id: "JetBrains-Research/lca-commit-message-generation" | ||
hf_repo_config: "full_files" | ||
hf_repo_split: "16k" | ||
logger: | ||
use_wandb: false | ||
name: null | ||
project: null | ||
local_artifact: null | ||
data_src: | ||
_target_: src.data_sources.HFDataSource | ||
cache_dir: null | ||
hub_name: JetBrains-Research/lca-commit-message-generation | ||
configs: | ||
- default | ||
split: test |
26 changes: 26 additions & 0 deletions
26
commit_message_generation/configs/async/DeepSeek-V3-16k.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
backbone: | ||
_target_: src.backbones.TogetherBackbone | ||
prompt: | ||
_target_: src.prompts.DetailedCMGPromptWContext | ||
model_name: deepseek-ai/DeepSeek-V3 | ||
api_key: null | ||
parameters: | ||
temperature: 0.8 | ||
preprocessor: | ||
_target_: src.preprocessors.LoadFromDatasetPreprocessor | ||
include_path: true | ||
hf_repo_id: "JetBrains-Research/lca-commit-message-generation" | ||
hf_repo_config: "retrieval_bm25" | ||
hf_repo_split: "16k" | ||
logger: | ||
use_wandb: false | ||
name: null | ||
project: null | ||
local_artifact: null | ||
data_src: | ||
_target_: src.data_sources.HFDataSource | ||
cache_dir: null | ||
hub_name: JetBrains-Research/lca-commit-message-generation | ||
configs: | ||
- default | ||
split: test |
26 changes: 26 additions & 0 deletions
26
commit_message_generation/configs/async/DeepSeek-V3-32k.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
backbone: | ||
_target_: src.backbones.TogetherBackbone | ||
prompt: | ||
_target_: src.prompts.DetailedCMGPromptWContext | ||
model_name: deepseek-ai/DeepSeek-V3 | ||
api_key: null | ||
parameters: | ||
temperature: 0.8 | ||
preprocessor: | ||
_target_: src.preprocessors.LoadFromDatasetPreprocessor | ||
include_path: true | ||
hf_repo_id: "JetBrains-Research/lca-commit-message-generation" | ||
hf_repo_config: "retrieval_bm25" | ||
hf_repo_split: "32k" | ||
logger: | ||
use_wandb: false | ||
name: null | ||
project: null | ||
local_artifact: null | ||
data_src: | ||
_target_: src.data_sources.HFDataSource | ||
cache_dir: null | ||
hub_name: JetBrains-Research/lca-commit-message-generation | ||
configs: | ||
- default | ||
split: test |
26 changes: 26 additions & 0 deletions
26
commit_message_generation/configs/async/DeepSeek-V3-4k-files.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
backbone: | ||
_target_: src.backbones.TogetherBackbone | ||
prompt: | ||
_target_: src.prompts.DetailedCMGPromptForFullFiles | ||
model_name: deepseek-ai/DeepSeek-V3 | ||
api_key: null | ||
parameters: | ||
temperature: 0.8 | ||
preprocessor: | ||
_target_: src.preprocessors.LoadFromDatasetPreprocessor | ||
include_path: true | ||
hf_repo_id: "JetBrains-Research/lca-commit-message-generation" | ||
hf_repo_config: "full_files" | ||
hf_repo_split: "4k" | ||
logger: | ||
use_wandb: false | ||
name: null | ||
project: null | ||
local_artifact: null | ||
data_src: | ||
_target_: src.data_sources.HFDataSource | ||
cache_dir: null | ||
hub_name: JetBrains-Research/lca-commit-message-generation | ||
configs: | ||
- default | ||
split: test |
26 changes: 26 additions & 0 deletions
26
commit_message_generation/configs/async/DeepSeek-V3-4k.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
backbone: | ||
_target_: src.backbones.TogetherBackbone | ||
prompt: | ||
_target_: src.prompts.DetailedCMGPromptWContext | ||
model_name: deepseek-ai/DeepSeek-V3 | ||
api_key: null | ||
parameters: | ||
temperature: 0.8 | ||
preprocessor: | ||
_target_: src.preprocessors.LoadFromDatasetPreprocessor | ||
include_path: true | ||
hf_repo_id: "JetBrains-Research/lca-commit-message-generation" | ||
hf_repo_config: "retrieval_bm25" | ||
hf_repo_split: "4k" | ||
logger: | ||
use_wandb: false | ||
name: null | ||
project: null | ||
local_artifact: null | ||
data_src: | ||
_target_: src.data_sources.HFDataSource | ||
cache_dir: null | ||
hub_name: JetBrains-Research/lca-commit-message-generation | ||
configs: | ||
- default | ||
split: test |
26 changes: 26 additions & 0 deletions
26
commit_message_generation/configs/async/DeepSeek-V3-64k.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
backbone: | ||
_target_: src.backbones.TogetherBackbone | ||
prompt: | ||
_target_: src.prompts.DetailedCMGPromptWContext | ||
model_name: deepseek-ai/DeepSeek-V3 | ||
api_key: null | ||
parameters: | ||
temperature: 0.8 | ||
preprocessor: | ||
_target_: src.preprocessors.LoadFromDatasetPreprocessor | ||
include_path: true | ||
hf_repo_id: "JetBrains-Research/lca-commit-message-generation" | ||
hf_repo_config: "retrieval_bm25" | ||
hf_repo_split: "64k" | ||
logger: | ||
use_wandb: false | ||
name: null | ||
project: null | ||
local_artifact: null | ||
data_src: | ||
_target_: src.data_sources.HFDataSource | ||
cache_dir: null | ||
hub_name: JetBrains-Research/lca-commit-message-generation | ||
configs: | ||
- default | ||
split: test |
Oops, something went wrong.