Skip to content

Commit 481fa59

Browse files
authored
Update classifier_classify_new.py
1 parent c49c1d5 commit 481fa59

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/embedding/classifier_classify_new.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
#import judgeutil
4343

4444
BASEDIR = os.getenv('RUNTIME_BASEDIR',os.path.abspath(os.path.dirname(__file__)))
45+
46+
47+
HAS_OPENCL = os.getenv('HAS_OPENCL','true')
4548
sys.path.append(BASEDIR)
4649
import judgeutil
4750

@@ -363,7 +366,10 @@ def train_svm_with_embedding(args_list):
363366
judge_paths = paths
364367
judge_labels = labels
365368
judge_nrof_images = len(judge_paths)
366-
judge_emb_array = np.zeros((judge_nrof_images, 512))
369+
if HAS_OPENCL == 'true:
370+
judge_emb_array = np.zeros((judge_nrof_images, 512))
371+
else:
372+
judge_emb_array = np.zeros((judge_nrof_images, 128))
367373
for j in range(judge_nrof_images):
368374
judge_embedding = None
369375
image_path = judge_paths[j]
@@ -390,7 +396,10 @@ def train_svm_with_embedding(args_list):
390396
nrof_images = len(paths)
391397
nrof_batches_per_epoch = int(math.ceil(1.0 * nrof_images / args.batch_size))
392398

393-
emb_array = np.zeros((nrof_images, 512))
399+
if HAS_OPENCL == 'true':
400+
emb_array = np.zeros((nrof_images, 512))
401+
else:
402+
emb_array = np.zeros((nrof_images, 128))
394403
for i in range(nrof_batches_per_epoch):
395404
start_index = i*args.batch_size
396405
end_index = min((i+1)*args.batch_size, nrof_images)

0 commit comments

Comments
 (0)