-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
49 changed files
with
1,338 additions
and
989 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
.idea/ | ||
*.pyc | ||
*.pyc | ||
build/ | ||
.vscode/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,69 @@ | ||
FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 | ||
MAINTAINER Shintaro Sakoda | ||
|
||
RUN set -x && \ | ||
: "必要なものをインストール" && \ | ||
apt-get update && \ | ||
apt-get install sudo -y && \ | ||
sudo apt-get install git -y && \ | ||
sudo apt-get install vim -y && \ | ||
sudo apt-get install cmake -y && \ | ||
sudo apt-get install python3 -y && \ | ||
sudo apt-get install python3-pip -y && \ | ||
sudo apt-get install p7zip-full -y && \ | ||
sudo apt-get install wget -y && \ | ||
sudo apt-get install curl -y && \ | ||
sudo apt-get install zip -y && \ | ||
sudo apt-get install unzip -y && \ | ||
pip3 install natsort && \ | ||
: "日本語の導入" && \ | ||
sudo apt-get install language-pack-ja-base language-pack-ja -y && \ | ||
echo "export LANG='ja_JP.UTF-8'" >> ~/.bashrc && \ | ||
: "Miacisの取得" && \ | ||
cd ~ && \ | ||
git clone https://github.com/SakodaShintaro/Miacis && \ | ||
: "libtorchの取得" && \ | ||
./Miacis/scripts/download_libtorch.sh && \ | ||
: "ビルド更新スクリプトの準備" && \ | ||
mkdir Miacis/src/cmake-build-release && \ | ||
cd Miacis/src/cmake-build-release && \ | ||
echo "git fetch" > update.sh && \ | ||
FROM nvcr.io/nvidia/pytorch:20.10-py3 | ||
|
||
RUN apt-get update && apt-get install -y curl gnupg && rm -rf /var/lib/apt/lists/* | ||
|
||
RUN curl https://bazel.build/bazel-release.pub.gpg | apt-key add - && \ | ||
echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list | ||
|
||
RUN apt-get update && apt-get install -y bazel-3.7.1 && rm -rf /var/lib/apt/lists/* | ||
RUN ln -s /usr/bin/bazel-3.7.1 /usr/bin/bazel | ||
|
||
RUN pip install notebook | ||
|
||
# trtorchの導入 | ||
WORKDIR /opt | ||
RUN git clone https://github.com/NVIDIA/TRTorch trtorch | ||
WORKDIR /opt/trtorch | ||
RUN git checkout 721b071f7166e1826183f28305823f406eac4807 | ||
RUN cp /opt/trtorch/docker/WORKSPACE.cu.docker /opt/trtorch/WORKSPACE | ||
|
||
# Workaround for bazel expecting both static and shared versions, we only use shared libraries inside container | ||
RUN cp /usr/lib/x86_64-linux-gnu/libnvinfer.so /usr/lib/x86_64-linux-gnu/libnvinfer_static.a | ||
|
||
WORKDIR /opt/trtorch | ||
RUN bazel build //:libtrtorch --compilation_mode opt | ||
|
||
WORKDIR /opt/trtorch/py | ||
|
||
RUN pip install ipywidgets | ||
RUN jupyter nbextension enable --py widgetsnbextension | ||
|
||
# Locale is not set by default | ||
RUN apt-get update && apt-get install -y locales ninja-build && rm -rf /var/lib/apt/lists/* && locale-gen en_US.UTF-8 | ||
ENV LANG en_US.UTF-8 | ||
ENV LANGUAGE en_US:en | ||
ENV LC_ALL en_US.UTF-8 | ||
RUN python3 setup.py install --use-cxx11-abi | ||
|
||
RUN conda init bash | ||
|
||
ENV LD_LIBRARY_PATH /opt/conda/lib/python3.6/site-packages/torch/lib:$LD_LIBRARY_PATH | ||
|
||
|
||
# ここから自分の設定 | ||
# 言語の設定 | ||
RUN apt-get update && apt-get install -y language-pack-ja-base language-pack-ja && rm -rf /var/lib/apt/lists/* | ||
ENV LANG='ja_JP.UTF-8' | ||
|
||
# 必要なもののインストール | ||
RUN apt-get update && apt-get install -y p7zip-full zip && rm -rf /var/lib/apt/lists/* | ||
RUN pip install natsort | ||
|
||
# trtorchを適切な場所へ展開 | ||
WORKDIR /root | ||
RUN tar xvf /opt/trtorch/bazel-bin/libtrtorch.tar.gz . | ||
|
||
# Miacisの導入 | ||
RUN git clone https://github.com/SakodaShintaro/Miacis | ||
RUN ./Miacis/scripts/download_libtorch.sh | ||
WORKDIR /root/Miacis/src/cmake-build-release | ||
RUN echo "git fetch" > update.sh && \ | ||
echo "git reset --hard origin/master" >> update.sh && \ | ||
echo "cmake -DCMAKE_BUILD_TYPE=Release .." >> update.sh && \ | ||
echo "make -j$(nproc)" >> update.sh && \ | ||
echo "make -j$(nproc) Miacis_shogi_categorical" >> update.sh && \ | ||
chmod +x update.sh && \ | ||
./update.sh && \ | ||
: "dotfilesの取得" && \ | ||
cd ~ && \ | ||
git clone https://github.com/SakodaShintaro/dotfiles && \ | ||
./dotfiles/setup.sh | ||
./update.sh | ||
|
||
# dotfileの導入 | ||
WORKDIR /root | ||
RUN git clone https://github.com/SakodaShintaro/dotfiles && ./dotfiles/setup.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
#!/usr/bin/env python3 | ||
import glob | ||
import os | ||
import re | ||
from natsort import natsorted | ||
from generate_torch_script_model import * | ||
|
||
|
||
# batch_normがある場合はちょっと特殊なので関数として切り出しておく | ||
def load_conv_and_norm(dst, src): | ||
dst.conv_.weight.data = src.conv_.weight.data | ||
dst.norm_.weight.data = src.norm_.weight.data | ||
dst.norm_.bias.data = src.norm_.bias.data | ||
dst.norm_.running_mean = src.norm_.running_mean | ||
dst.norm_.running_var = src.norm_.running_var | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--source_dir", type=str, required=True) | ||
parser.add_argument("--game", default="shogi", choices=["shogi", "othello"]) | ||
args = parser.parse_args() | ||
|
||
if args.game == "shogi": | ||
input_channel_num = 42 | ||
board_size = 9 | ||
policy_channel_num = 27 | ||
elif args.game == "othello": | ||
input_channel_num = 2 | ||
board_size = 8 | ||
policy_channel_num = 2 | ||
|
||
# ディレクトリにある以下のprefixを持ったパラメータを用いて対局を行う | ||
source_model_names = natsorted(glob.glob(f"{args.source_dir}/*.model")) | ||
|
||
# 1番目のモデル名からブロック数,チャンネル数を読み取る.これらは1ディレクトリ内で共通だという前提 | ||
basename_without_ext = os.path.splitext(os.path.basename(source_model_names[0]))[0] | ||
parts = basename_without_ext.split("_") | ||
block_num = None | ||
channel_num = None | ||
for p in parts: | ||
if "bl" in p: | ||
block_num = int(re.sub("\\D", "", p)) | ||
elif "ch" in p: | ||
channel_num = int(re.sub("\\D", "", p)) | ||
|
||
# インスタンス生成 | ||
model = CategoricalNetwork(input_channel_num, block_num, channel_num, policy_channel_num, board_size) | ||
|
||
# 各モデルファイルのパラメータをコピーしてTorchScriptとして保存 | ||
for source_model_name in source_model_names: | ||
source = torch.jit.load(source_model_name).cpu() | ||
|
||
# first_conv | ||
load_conv_and_norm(model.encoder_.first_conv_and_norm_, source.state_first_conv_and_norm_) | ||
|
||
# block | ||
for i, v in enumerate(model.encoder_.__dict__["_modules"]["blocks"]): | ||
source_m = source.__dict__["_modules"][f"state_blocks_{i}"] | ||
load_conv_and_norm(v.conv_and_norm0_, source_m.conv_and_norm0_) | ||
load_conv_and_norm(v.conv_and_norm1_, source_m.conv_and_norm1_) | ||
v.linear0_.weight.data = source_m.linear0_.weight.data | ||
v.linear1_.weight.data = source_m.linear1_.weight.data | ||
|
||
# policy_conv | ||
model.policy_head_.policy_conv_.weight.data = source.policy_conv_.weight.data | ||
model.policy_head_.policy_conv_.bias.data = source.policy_conv_.bias.data | ||
|
||
# value_conv_norm_ | ||
load_conv_and_norm(model.value_head_.value_conv_and_norm_, source.value_conv_and_norm_) | ||
|
||
# value_linear | ||
model.value_head_.value_linear0_.weight.data = source.value_linear0_.weight.data | ||
model.value_head_.value_linear0_.bias.data = source.value_linear0_.bias.data | ||
model.value_head_.value_linear1_.weight.data = source.value_linear1_.weight.data | ||
model.value_head_.value_linear1_.bias.data = source.value_linear1_.bias.data | ||
|
||
input_data = torch.ones([1, input_channel_num, board_size, board_size]) | ||
model.eval() | ||
script_model = torch.jit.trace(model, input_data) | ||
# script_model = torch.jit.script(model) | ||
model_path = f"{args.game}_{os.path.basename(source_model_name)}" | ||
script_model.save(model_path) | ||
print(f"{model_path}にパラメータを保存") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,26 @@ | ||
echo "\$1(1番目の引数): $1" | ||
|
||
# どこに保存するかの基準位置($0 = ./の2つ上がMiacisと同階層なのでそこに置く) | ||
root_dir=$(dirname "$0")/../../data | ||
|
||
# 棋譜のダウンロード | ||
download_path=${root_dir}/floodgate_kifu | ||
mkdir -p "${download_path}" | ||
wget -P "${download_path}" "http://wdoor.c.u-tokyo.ac.jp/shogi/x/wdoor2015.7z" | ||
wget -P "${download_path}" "http://wdoor.c.u-tokyo.ac.jp/shogi/x/wdoor2016.7z" | ||
wget -P "${download_path}" "http://wdoor.c.u-tokyo.ac.jp/shogi/x/wdoor2017.7z" | ||
wget -P "${download_path}" "http://wdoor.c.u-tokyo.ac.jp/shogi/x/wdoor2018.7z" | ||
wget -P "${download_path}" "http://wdoor.c.u-tokyo.ac.jp/shogi/x/wdoor2019.7z" | ||
# wget -P "${download_path}" "http://wdoor.c.u-tokyo.ac.jp/shogi/x/wdoor2016.7z" | ||
# wget -P "${download_path}" "http://wdoor.c.u-tokyo.ac.jp/shogi/x/wdoor2017.7z" | ||
# wget -P "${download_path}" "http://wdoor.c.u-tokyo.ac.jp/shogi/x/wdoor2018.7z" | ||
# wget -P "${download_path}" "http://wdoor.c.u-tokyo.ac.jp/shogi/x/wdoor2019.7z" | ||
|
||
# 学習用データ(2016年以降) | ||
train_path=${download_path}/train | ||
mkdir -p "${train_path}" | ||
7z e "${download_path}"/wdoor2016.7z -o"${train_path}" | ||
7z e "${download_path}"/wdoor2017.7z -o"${train_path}" | ||
7z e "${download_path}"/wdoor2018.7z -o"${train_path}" | ||
7z e "${download_path}"/wdoor2019.7z -o"${train_path}" | ||
# train_path=${download_path}/train | ||
# mkdir -p "${train_path}" | ||
# 7z e "${download_path}"/wdoor2016.7z -o"${train_path}" | ||
# 7z e "${download_path}"/wdoor2017.7z -o"${train_path}" | ||
# 7z e "${download_path}"/wdoor2018.7z -o"${train_path}" | ||
# 7z e "${download_path}"/wdoor2019.7z -o"${train_path}" | ||
|
||
# 検証用データ(2015年) | ||
valid_path=${download_path}/valid | ||
mkdir -p "${valid_path}" | ||
7z e "${download_path}"/wdoor2015.7z -o"${valid_path}" | ||
7z e "${download_path}"/wdoor2015.7z -o"${valid_path}" |
Oops, something went wrong.