Skip to content

Commit bda6c86

Browse files
authored
[binding] download model auto (#1234)
* [binding] download model auto * fix model dir * [binding] fix inline and decouple model download logic * Update decoder.py * [binding] remove unnecessary blank * [binding] try fix lint * add license
1 parent cf0687b commit bda6c86

File tree

3 files changed

+118
-7
lines changed

3 files changed

+118
-7
lines changed

runtime/binding/python/py/decoder.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List
15+
from typing import List, Optional
1616

1717
import _wenet
1818

19+
from .hub import Hub
20+
1921

2022
class Decoder:
23+
2124
def __init__(self,
22-
model_dir: str,
25+
model_dir: Optional[str] = None,
2326
lang: str = 'chs',
2427
nbest: int = 1,
2528
enable_timestamp: bool = False,
26-
context: List[str] = None,
29+
context: Optional[List[str]] = None,
2730
context_score: float = 3.0):
2831
""" Init WeNet decoder
2932
Args:
@@ -34,7 +37,11 @@ def __init__(self,
3437
context: context words
3538
context_score: bonus score when the context is matched
3639
"""
40+
if model_dir is None:
41+
model_dir = Hub.get_model_by_lang(lang)
42+
3743
self.d = _wenet.wenet_init(model_dir)
44+
3845
self.set_language(lang)
3946
self.set_nbest(nbest)
4047
self.enable_timestamp(enable_timestamp)

runtime/binding/python/py/hub.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) 2022 Mddct(hamddct@gmail.com)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import tarfile
17+
from pathlib import Path
18+
from urllib.request import urlretrieve
19+
20+
import tqdm
21+
22+
23+
def download(url: str, dest: str, only_child=True):
    """Download a tar archive from ``url`` into ``dest`` and unpack it there.

    The archive is stored as ``dest/<last path component of url>`` and is
    left in place after extraction. A tqdm progress bar reports the transfer.

    Args:
        url: location of the archive, e.g. ``.../model.tar.gz``.
        dest: existing directory receiving the archive and its contents.
        only_child: if True, every regular file located inside a directory
            of the archive is written flat to ``dest/<basename>``; members
            at the archive root and non-file members are skipped. If False,
            the archive is extracted into ``dest`` with its layout kept.

    Raises:
        FileNotFoundError: if ``dest`` does not exist.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if not os.path.exists(dest):
        raise FileNotFoundError(
            "download destination does not exist: {}".format(dest))

    def progress_hook(t):
        # urlretrieve reports the cumulative block count; tqdm wants deltas,
        # so remember the previous count in a mutable cell.
        last_b = [0]

        def update_to(b=1, bsize=1, tsize=None):
            if tsize not in (None, -1):
                t.total = tsize
            displayed = t.update((b - last_b[0]) * bsize)
            last_b[0] = b
            return displayed

        return update_to

    # *.tar.gz
    name = url.split("/")[-1]
    tar_path = os.path.join(dest, name)
    with tqdm.tqdm(unit='B',
                   unit_scale=True,
                   unit_divisor=1024,
                   miniters=1,
                   desc=(name)) as t:
        urlretrieve(url,
                    filename=tar_path,
                    reporthook=progress_hook(t),
                    data=None)
        # Make the bar end at 100% even if the server sent no length.
        t.total = t.n

    with tarfile.open(tar_path) as f:
        if not only_child:
            # Use the "data" extraction filter when this Python provides it,
            # so members cannot escape ``dest``.
            if hasattr(tarfile, "data_filter"):
                f.extractall(dest, filter="data")
            else:
                f.extractall(dest)
        else:
            for tarinfo in f:
                # Only regular files nested inside a directory are flattened.
                if "/" not in tarinfo.name or not tarinfo.isfile():
                    continue
                member_name = os.path.basename(tarinfo.name)
                if not member_name:
                    continue
                src = f.extractfile(tarinfo)
                if src is None:
                    continue
                # Write the payload straight to dest/<basename>.
                # ``f.extract(member, path)`` would instead recreate the
                # member's full archive path below ``path``.
                with src, open(os.path.join(dest, member_name), "wb") as out:
                    while True:
                        chunk = src.read(1 << 20)
                        if not chunk:
                            break
                        out.write(chunk)
62+
63+
64+
class Hub(object):
    """Hub for wenet pretrain runtime model

    Maps a language code to the URL of a pretrained runtime model archive
    and caches the unpacked model under ``~/.wenet/<lang>``.
    """
    # TODO(Mddct): make assets class to support other language
    Assets = {
        # wenetspeech
        "chs":
        "https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/wenetspeech/20220506_u2pp_conformer_libtorch.tar.gz",
        # gigaspeech
        "en":
        "https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/gigaspeech/20210728_u2pp_conformer_libtorch.tar.gz"
    }

    def __init__(self) -> None:
        pass

    @staticmethod
    def get_model_by_lang(lang: str) -> str:
        """Return the local model directory for ``lang``, downloading if needed.

        Args:
            lang: language code; must be a key of ``Hub.Assets``.

        Returns:
            Path of the directory ``~/.wenet/<lang>``.

        Raises:
            ValueError: if ``lang`` has no entry in ``Hub.Assets``.
        """
        # Explicit check instead of ``assert``: asserts vanish under
        # ``python -O``, which would surface as a bare KeyError below.
        if lang not in Hub.Assets:
            raise ValueError("unsupported language {!r}, expected one of {}".format(
                lang, sorted(Hub.Assets)))
        # NOTE(Mddct): model_dir structure
        # Path.home()/.wenet
        # - chs
        #    - units.txt
        #    - final.zip
        # - en
        #    - units.txt
        #    - final.zip
        model_url = Hub.Assets[lang]
        model_dir = os.path.join(Path.home(), ".wenet", lang)
        # exist_ok avoids failing when the directory already exists or is
        # created concurrently between a check and the mkdir.
        os.makedirs(model_dir, exist_ok=True)
        # TODO(Mddct): model metadata
        # The cache counts as valid only if both artifacts are regular files;
        # a same-named directory must not be mistaken for a model.
        required = ("final.zip", "units.txt")
        if all(
                os.path.isfile(os.path.join(model_dir, fname))
                for fname in required):
            return model_dir
        download(model_url, model_dir, only_child=True)
        return model_dir

runtime/binding/python/setup.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ def build_extension(self, ext: setuptools.extension.Extension):
4848
libs = []
4949
torch_lib = 'fc_base/libtorch-src/lib'
5050
for ext in ['so', 'pyd']:
51-
libs.extend(glob.glob(
52-
f"{self.build_temp}/**/_wenet*.{ext}", recursive=True))
51+
libs.extend(
52+
glob.glob(f"{self.build_temp}/**/_wenet*.{ext}",
53+
recursive=True))
5354
for ext in ['so', 'dylib', 'dll']:
54-
libs.extend(glob.glob(
55-
f"{self.build_temp}/**/*wenet_api.{ext}", recursive=True))
55+
libs.extend(
56+
glob.glob(f"{self.build_temp}/**/*wenet_api.{ext}",
57+
recursive=True))
5658
libs.extend(glob.glob(f'{src_dir}/{torch_lib}/*c10.{ext}'))
5759
libs.extend(glob.glob(f'{src_dir}/{torch_lib}/*torch_cpu.{ext}'))
5860

@@ -95,6 +97,8 @@ def read_long_description():
9597
ext_modules=[cmake_extension("_wenet")],
9698
cmdclass={"build_ext": BuildExtension},
9799
zip_safe=False,
100+
setup_requires=["tqdm"],
101+
install_requires=["tqdm"],
98102
classifiers=[
99103
"Programming Language :: C++",
100104
"Programming Language :: Python",

0 commit comments

Comments
 (0)