Skip to content

Commit

Permalink
Fixed small bug, added progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
1adrianb committed Sep 11, 2017
1 parent 75a5fc9 commit ab9eceb
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 25 deletions.
44 changes: 24 additions & 20 deletions main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require('cunn')
require('cudnn')

-- Load optional data-loading libraries
xrequire('matio') -- matlab
matio = xrequire('matio') -- matlab
npy4th = xrequire('npy4th') -- python numpy

local FaceDetector = require 'facedetection_dlib'
Expand All @@ -22,10 +22,11 @@ torch.setheaptracking(true)
torch.setdefaulttensortype('torch.FloatTensor')
torch.setnumthreads(1)

local fileList = utils.getFileList(opts)
local fileList, requireDetectionCnt = utils.getFileList(opts)
local predictions = {}

local faceDetector = FaceDetector()
local faceDetector = nil
if requireDetectionCnt > 0 then faceDetector = FaceDetector() end

local model = torch.load(opts.model)
local modelZ
Expand Down Expand Up @@ -76,31 +77,34 @@ for i = 1, #fileList do
if opts.device ~= 'cpu' then inputZ = inputZ:cuda() end
local depth_pred = modelZ:forward(inputZ):float():view(68,1)
preds_hm = torch.cat(preds_hm, depth_pred, 2)
preds_img = torch.cat(preds_img:view(68,2), depth_pred*(1/(256/(200*fileList[i].scale))),2)
end

if opts.mode == 'demo' then
-- Converting it to the predicted space (for plotting)
detectedFace[{{3,4}}] = utils.transform(torch.Tensor({detectedFace[3],detectedFace[4]}), fileList[i].center, fileList[i].scale, 256)
detectedFace[{{1,2}}] = utils.transform(torch.Tensor({detectedFace[1],detectedFace[2]}), fileList[i].center, fileList[i].scale, 256)
if detectedFace ~= nil then
-- Converting it to the predicted space (for plotting)
detectedFace[{{3,4}}] = utils.transform(torch.Tensor({detectedFace[3],detectedFace[4]}), fileList[i].center, fileList[i].scale, 256)
detectedFace[{{1,2}}] = utils.transform(torch.Tensor({detectedFace[1],detectedFace[2]}), fileList[i].center, fileList[i].scale, 256)

detectedFace[3] = detectedFace[3]-detectedFace[1]
detectedFace[4] = detectedFace[4]-detectedFace[2]

detectedFace[3] = detectedFace[3]-detectedFace[1]
detectedFace[4] = detectedFace[4]-detectedFace[2]
end
utils.plot(img, preds_hm, detectedFace)
end

if opts.save then
local dest = opts.output..'/'..paths.basename(fileList[i].image, '.'..paths.extname(fileList[i].image))
if opts.outputFormat == 't7' then
torch.save(dest..'.t7', preds_img)
elseif opts.outputFormat == 'txt' then
-- csv without header
local out = torch.DiskFile(dest .. '.txt', 'w')
for i=1,68 do
out:writeString(tostring(preds_img[{1,i,1}]) .. ',' .. tostring(preds_img[{1,i,2}]) .. '\n')
end
out:close()
end
local dest = opts.output..'/'..paths.basename(fileList[i].image, '.'..paths.extname(fileList[i].image))
if opts.outputFormat == 't7' then
torch.save(dest..'.t7', preds_img)
elseif opts.outputFormat == 'txt' then
-- csv without header
local out = torch.DiskFile(dest .. '.txt', 'w')
for i=1,68 do
out:writeString(tostring(preds_img[{1,i,1}]) .. ',' .. tostring(preds_img[{1,i,2}]) .. '\n')
end
out:close()
end
xlua.progress(i, #fileList)
end


Expand Down
10 changes: 5 additions & 5 deletions utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ if preds.shape[1]==2:
ax.plot(preds[48:60,0],preds[48:60,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.plot(preds[60:68,0],preds[60:68,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
if detected_face is not None:
if ('detected_face' in vars() or 'detected_face' in globals()) and (detected_face is not None):
ax.add_patch(
patches.Rectangle(
(detected_face[0], detected_face[1]),
Expand All @@ -231,7 +231,7 @@ elif preds.shape[1]==3:
ax.plot(preds[60:68,0],preds[60:68,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.axis('off')
if detected_face is not None:
if ('detected_face' in vars() or 'detected_face' in globals()) and (detected_face is not None):
ax.add_patch(
patches.Rectangle(
(detected_face[0], detected_face[1]),
Expand Down Expand Up @@ -296,8 +296,8 @@ function utils.getFileList(opts)
local requireDetectionCnt = 0
for f in paths.files(data_path, function (file) return file:find('.jpg') or file:find('.png') end) do
-- Check if we have .t7, .mat, .npy or .pts file
local pts = utils.loadUnkownFile(data_path..f:sub(1,#f-4))
local data_pts = {}
local pts = utils.loadUnkownFile(paths.concat(data_path,f:sub(1,#f-4)))
local data_pts = {}

if pts ~= nil then
local center, scale, normby = utils.bounding_box(pts)
Expand Down Expand Up @@ -329,7 +329,7 @@ function utils.getFileList(opts)
end
print('Found '..#filesList..' images')
print(requireDetectionCnt..' images require a face detector')
return filesList
return filesList, requireDetectionCnt
end

function utils.calculateMetrics(dists)
Expand Down

0 comments on commit ab9eceb

Please sign in to comment.