diff --git a/main.lua b/main.lua index 80ffbe2..33e43eb 100644 --- a/main.lua +++ b/main.lua @@ -13,7 +13,7 @@ require('cunn') require('cudnn') -- Load optional data-loading libraries -xrequire('matio') -- matlab +matio = xrequire('matio') -- matlab npy4th = xrequire('npy4th') -- python numpy local FaceDetector = require 'facedetection_dlib' @@ -22,10 +22,11 @@ torch.setheaptracking(true) torch.setdefaulttensortype('torch.FloatTensor') torch.setnumthreads(1) -local fileList = utils.getFileList(opts) +local fileList, requireDetectionCnt = utils.getFileList(opts) local predictions = {} -local faceDetector = FaceDetector() +local faceDetector = nil +if requireDetectionCnt > 0 then faceDetector = FaceDetector() end local model = torch.load(opts.model) local modelZ @@ -76,31 +77,34 @@ for i = 1, #fileList do if opts.device ~= 'cpu' then inputZ = inputZ:cuda() end local depth_pred = modelZ:forward(inputZ):float():view(68,1) preds_hm = torch.cat(preds_hm, depth_pred, 2) + preds_img = torch.cat(preds_img:view(68,2), depth_pred*(1/(256/(200*fileList[i].scale))),2) end if opts.mode == 'demo' then - -- Converting it to the predicted space (for plotting) - detectedFace[{{3,4}}] = utils.transform(torch.Tensor({detectedFace[3],detectedFace[4]}), fileList[i].center, fileList[i].scale, 256) - detectedFace[{{1,2}}] = utils.transform(torch.Tensor({detectedFace[1],detectedFace[2]}), fileList[i].center, fileList[i].scale, 256) + if detectedFace ~= nil then + -- Converting it to the predicted space (for plotting) + detectedFace[{{3,4}}] = utils.transform(torch.Tensor({detectedFace[3],detectedFace[4]}), fileList[i].center, fileList[i].scale, 256) + detectedFace[{{1,2}}] = utils.transform(torch.Tensor({detectedFace[1],detectedFace[2]}), fileList[i].center, fileList[i].scale, 256) - detectedFace[3] = detectedFace[3]-detectedFace[1] - detectedFace[4] = detectedFace[4]-detectedFace[2] - + detectedFace[3] = detectedFace[3]-detectedFace[1] + detectedFace[4] = detectedFace[4]-detectedFace[2] + end utils.plot(img, preds_hm, detectedFace) end if opts.save then - local dest = opts.output..'/'..paths.basename(fileList[i].image, '.'..paths.extname(fileList[i].image)) - if opts.outputFormat == 't7' then - torch.save(dest..'.t7', preds_img) - elseif opts.outputFormat == 'txt' then - -- csv without header - local out = torch.DiskFile(dest .. '.txt', 'w') - for i=1,68 do - out:writeString(tostring(preds_img[{1,i,1}]) .. ',' .. tostring(preds_img[{1,i,2}]) .. '\n') - end - out:close() - end + local dest = opts.output..'/'..paths.basename(fileList[i].image, '.'..paths.extname(fileList[i].image)) + if opts.outputFormat == 't7' then + torch.save(dest..'.t7', preds_img) + elseif opts.outputFormat == 'txt' then + -- csv without header + local out = torch.DiskFile(dest .. '.txt', 'w') + for i=1,68 do + out:writeString(tostring(preds_img[{1,i,1}]) .. ',' .. tostring(preds_img[{1,i,2}]) .. '\n') + end + out:close() + end + xlua.progress(i, #fileList) end diff --git a/utils.lua b/utils.lua index 1768f7b..d2c3923 100644 --- a/utils.lua +++ b/utils.lua @@ -204,7 +204,7 @@ if preds.shape[1]==2: ax.plot(preds[48:60,0],preds[48:60,1],marker='o',markersize=6,linestyle='-',color='w',lw=2) ax.plot(preds[60:68,0],preds[60:68,1],marker='o',markersize=6,linestyle='-',color='w',lw=2) - if detected_face is not None: + if ('detected_face' in vars() or 'detected_face' in globals()) and (detected_face is not None): ax.add_patch( patches.Rectangle( (detected_face[0], detected_face[1]), @@ -231,7 +231,7 @@ elif preds.shape[1]==3: ax.plot(preds[60:68,0],preds[60:68,1],marker='o',markersize=6,linestyle='-',color='w',lw=2) ax.axis('off') - if detected_face is not None: + if ('detected_face' in vars() or 'detected_face' in globals()) and (detected_face is not None): ax.add_patch( patches.Rectangle( (detected_face[0], detected_face[1]), @@ -296,8 +296,8 @@ function utils.getFileList(opts) local requireDetectionCnt = 0 for f in paths.files(data_path, function (file) return file:find('.jpg') or file:find('.png') end) do -- Check if we have .t7, .mat, .npy or .pts file - local pts = utils.loadUnkownFile(data_path..f:sub(1,#f-4)) - local data_pts = {} + local pts = utils.loadUnkownFile(paths.concat(data_path,f:sub(1,#f-4))) + local data_pts = {} if pts ~= nil then local center, scale, normby = utils.bounding_box(pts) @@ -329,7 +329,7 @@ function utils.getFileList(opts) end print('Found '..#filesList..' images') print(requireDetectionCnt..' images require a face detector') - return filesList + return filesList, requireDetectionCnt end function utils.calculateMetrics(dists)