18
18
import pytest
19
19
import torch
20
20
21
- from open_dubbing .voice_gender_classifier import ( # Assuming this is the file name
22
- VoiceGenderClassifier ,
23
- )
21
+ from open_dubbing .voice_gender_classifier import VoiceGenderClassifier
24
22
25
23
26
24
class TestVoiceGenderClassifier :
25
+ @classmethod
26
+ def setup_class (cls ):
27
+ """Set up the VoiceGenderClassifier model once for all tests."""
28
+ cls .classifier = VoiceGenderClassifier ()
27
29
28
30
@pytest .mark .parametrize (
29
31
"logits_gender, expected_gender" ,
@@ -33,17 +35,14 @@ class TestVoiceGenderClassifier:
33
35
],
34
36
)
35
37
def test_interpret_gender (self , logits_gender , expected_gender ):
36
- classifier = VoiceGenderClassifier ()
37
- predicted_gender = classifier ._interpret_gender (logits_gender )
38
-
38
+ predicted_gender = self .classifier ._interpret_gender (logits_gender )
39
39
assert predicted_gender == expected_gender
40
40
41
41
def test_load_audio_file (self ):
42
42
data_dir = os .path .dirname (os .path .realpath (__file__ ))
43
43
filename = os .path .join (data_dir , "data/this_is_a_test.mp3" )
44
44
45
- classifier = VoiceGenderClassifier ()
46
- samples , target_sampling_rate = classifier .load_audio_file (filename )
45
+ samples , target_sampling_rate = self .classifier .load_audio_file (filename )
47
46
sample_sum = np .sum (samples )
48
47
assert 16000 == target_sampling_rate
49
48
assert np .isclose (sample_sum , - 19.797165 , atol = 2 )
@@ -52,6 +51,5 @@ def test_get_gender_for_file(self):
52
51
data_dir = os .path .dirname (os .path .realpath (__file__ ))
53
52
filename = os .path .join (data_dir , "data/this_is_a_test.mp3" )
54
53
55
- classifier = VoiceGenderClassifier ()
56
- gender = classifier .get_gender_for_file (filename )
54
+ gender = self .classifier .get_gender_for_file (filename )
57
55
assert "Male" == gender
0 commit comments