Commit 123ade1

Add files via upload
1 parent a3fa8e5 commit 123ade1

49 files changed (+370 / -0 lines)

Checkpoints/-47999.index

5.15 KB
Binary file not shown.

Checkpoints/-47999.meta

672 KB
Binary file not shown.

Checkpoints/checkpoint

+6
@@ -0,0 +1,6 @@
model_checkpoint_path: "-47999"
all_model_checkpoint_paths: "-47199"
all_model_checkpoint_paths: "-47399"
all_model_checkpoint_paths: "-47599"
all_model_checkpoint_paths: "-47799"
all_model_checkpoint_paths: "-47999"
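The bare "-47999" style names follow from how train.py (below) saves: saver.save(sess, CKPT_DIR, index) uses the prefix './Checkpoints/', which ends in a slash, so TensorFlow appends the step number directly. A minimal standalone sketch of that behavior (assumes the TF 1.x environment this repo targets; not part of the commit):

import os
import tensorflow as tf

os.makedirs('./Checkpoints', exist_ok=True)
v = tf.Variable(0)  # Saver requires at least one variable
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # the prefix ends in '/', so the step is appended as '-<step>':
    print(saver.save(sess, './Checkpoints/', global_step=47999))  # ./Checkpoints/-47999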

output_small/out_0.png

312 KB

output_small/out_1.png

367 KB

output_small/out_10.png

359 KB

output_small/out_11.png

331 KB

output_small/out_12.png

346 KB

output_small/out_13.png

387 KB

output_small/out_14.png

348 KB

output_small/out_15.png

366 KB

output_small/out_16.png

320 KB

output_small/out_17.png

360 KB

output_small/out_18.png

348 KB

output_small/out_19.png

373 KB

output_small/out_2.png

368 KB

output_small/out_20.png

364 KB

output_small/out_21.png

306 KB

output_small/out_22.png

395 KB

output_small/out_23.png

358 KB

output_small/out_24.png

348 KB

output_small/out_25.png

316 KB

output_small/out_26.png

370 KB

output_small/out_27.png

365 KB

output_small/out_28.png

327 KB

output_small/out_29.png

342 KB

output_small/out_3.png

352 KB

output_small/out_30.png

348 KB

output_small/out_31.png

344 KB

output_small/out_32.png

395 KB

output_small/out_33.png

237 KB

output_small/out_34.png

290 KB

output_small/out_35.png

309 KB

output_small/out_36.png

335 KB

output_small/out_37.png

351 KB

output_small/out_38.png

367 KB

output_small/out_39.png

350 KB

output_small/out_4.png

334 KB

output_small/out_5.png

342 KB

output_small/out_6.png

353 KB

output_small/out_7.png

348 KB

output_small/out_8.png

380 KB

output_small/out_9.png

357 KB

src/conv_helper.py

+37
@@ -0,0 +1,37 @@
import tensorflow as tf

from utils import *

import tensorflow.contrib.slim as slim

def conv_layer(input_image, ksize, in_channels, out_channels, stride, scope_name, activation_function=lrelu, reuse=False):
    with tf.variable_scope(scope_name, reuse=reuse):
        filter = tf.Variable(tf.random_normal([ksize, ksize, in_channels, out_channels], stddev=0.03))
        output = tf.nn.conv2d(input_image, filter, strides=[1, stride, stride, 1], padding='SAME')
        output = slim.batch_norm(output)
        if activation_function:
            output = activation_function(output)
        return output, filter

def resize_deconvolution_layer(input_tensor, new_shape, scope_name):
    with tf.variable_scope(scope_name):
        # resize-then-convolve upsampling (a common alternative to strided
        # transpose convolutions); method=1 is nearest-neighbor
        output = tf.image.resize_images(input_tensor, (new_shape[1], new_shape[2]), method=1)
        output, unused_weights = conv_layer(output, 3, new_shape[3]*2, new_shape[3], 1, scope_name+"_deconv")
        return output

def deconvolution_layer(input_tensor, new_shape, scope_name):
    return resize_deconvolution_layer(input_tensor, new_shape, scope_name)

def residual_layer(input_image, ksize, in_channels, out_channels, stride, scope_name):
    with tf.variable_scope(scope_name):
        output, filter = conv_layer(input_image, ksize, in_channels, out_channels, stride, scope_name+"_conv1")
        output, filter = conv_layer(output, ksize, out_channels, out_channels, stride, scope_name+"_conv2")
        output = tf.add(output, tf.identity(input_image))
        return output, filter

def transpose_deconvolution_layer(input_tensor, used_weights, new_shape, stride, scope_name):
    with tf.variable_scope(scope_name):
        output = tf.nn.conv2d_transpose(input_tensor, used_weights, output_shape=new_shape, strides=[1, stride, stride, 1], padding='SAME')
        output = tf.nn.relu(output)
        return output
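Note the contract of resize_deconvolution_layer: its input must carry new_shape[3]*2 channels, which the inner conv then halves. A minimal shape sketch (run from src/; example_input is hypothetical, assuming the TF 1.x + slim environment above):

import tensorflow as tf
from conv_helper import conv_layer, deconvolution_layer

example_input = tf.placeholder(tf.float32, [8, 256, 256, 3])
feat, _ = conv_layer(example_input, 3, 3, 32, 2, "sketch_conv")     # (8, 128, 128, 32)
up = deconvolution_layer(feat, [8, 256, 256, 16], "sketch_deconv")  # 32 channels in -> 16 out
print(feat.get_shape(), up.get_shape())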

src/generate.py

+35
@@ -0,0 +1,35 @@
import cv2
import numpy as np
import glob
import os

img_dir = '/media/khushal/679f068d-921b-4d14-890f-3081c1728f98/Rephrase/rephrase_data/test/*'
deg_dir = '/media/khushal/679f068d-921b-4d14-890f-3081c1728f98/Rephrase/rephrase_data/degrade_test/'


def degrade(input_path: str, output_path: str) -> None:
    """Load image at `input_path`, distort and save as `output_path`"""
    SHIFT = 2
    image = cv2.imread(input_path)
    orig_img = image
    # mark ~20% of pixel positions, then swap each marked pixel with the
    # pixel SHIFT rows below it (only where that lower pixel is unmarked)
    to_swap = np.random.choice([False, True], image.shape[:2], p=[.8, .2])
    swap_indices = np.where(to_swap[:-SHIFT] & ~to_swap[SHIFT:])
    swap_vals = image[swap_indices[0] + SHIFT, swap_indices[1]]
    image[swap_indices[0] + SHIFT, swap_indices[1]] = image[swap_indices]
    image[swap_indices] = swap_vals
    cv2.imwrite(output_path, image)


imgtypes = glob.glob(img_dir)
base_imgtypes = [os.path.basename(f) for f in imgtypes]
print(base_imgtypes)
for i in range(83):  # assumes exactly 83 class directories under img_dir
    # os.makedirs(deg_dir+base_imgtypes[i])  # uncomment on a first run: cv2.imwrite fails silently if the directory is missing
    class_img = glob.glob(imgtypes[i] + "/*.jpg")
    img_name = [os.path.basename(f) for f in class_img]
    # print(img_name)
    for j in range(len(class_img)):
        f = class_img[j]
        degrade(f, deg_dir + base_imgtypes[i] + "/" + img_name[j])
    # print(len(class_img))
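The swap logic exchanges pixel (r, c) with (r+SHIFT, c) wherever to_swap is set at row r but not at row r+SHIFT. A self-contained toy run (not part of the commit) makes that visible:

import numpy as np

SHIFT = 2
image = np.arange(6).reshape(6, 1)   # one-column "image", pixel value = row index
to_swap = np.array([[True], [False], [False], [False], [True], [False]])
swap_indices = np.where(to_swap[:-SHIFT] & ~to_swap[SHIFT:])
swap_vals = image[swap_indices[0] + SHIFT, swap_indices[1]]  # fancy indexing copies
image[swap_indices[0] + SHIFT, swap_indices[1]] = image[swap_indices]
image[swap_indices] = swap_vals
print(image.ravel())  # [2 1 0 3 4 5] -- rows 0 and 2 swapped; the mark at row 4 has no partner in range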

src/model.py

+32
@@ -0,0 +1,32 @@
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from utils import *
from conv_helper import *


def generator(input):
    conv1, conv1_weights = conv_layer(input, 9, 3, 32, 1, "g_conv1")
    conv2, conv2_weights = conv_layer(conv1, 3, 32, 64, 1, "g_conv2")
    conv3, conv3_weights = conv_layer(conv2, 3, 64, 128, 1, "g_conv3")

    res1, res1_weights = residual_layer(conv3, 3, 128, 128, 1, "g_res1")
    # res2, res2_weights = residual_layer(res1, 3, 128, 128, 1, "g_res2")
    # res3, res3_weights = residual_layer(res2, 3, 128, 128, 1, "g_res3")
    deconv1 = deconvolution_layer(res1, [BATCH_SIZE, 128, 128, 64], 'g_deconv1')
    deconv2 = deconvolution_layer(deconv1, [BATCH_SIZE, 256, 256, 32], "g_deconv2")
    deconv2 = deconv2 + conv1  # skip connection from the first conv
    conv4, conv4_weights = conv_layer(deconv2, 9, 32, 3, 1, "g_conv5", activation_function=tf.nn.tanh)
    conv4 = conv4 + input  # global residual: the generator predicts a correction to the input
    return conv4

def discriminator(input, reuse=False):
    # six stride-2 convs: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4
    conv1, conv1_weights = conv_layer(input, 4, 3, 48, 2, "d_conv1", reuse=reuse)
    conv2, conv2_weights = conv_layer(conv1, 4, 48, 96, 2, "d_conv2", reuse=reuse)
    conv3, conv3_weights = conv_layer(conv2, 4, 96, 192, 2, "d_conv3", reuse=reuse)
    conv4, conv4_weights = conv_layer(conv3, 4, 192, 384, 2, "d_conv4", reuse=reuse)
    conv5, conv5_weights = conv_layer(conv4, 4, 384, 1, 2, "d_conv5", reuse=reuse)
    conv6, conv6_weights = conv_layer(conv5, 4, 1, 1, 2, "d_conv6", reuse=reuse)
    avgpool = tf.nn.avg_pool(conv6, ksize=(1, 4, 4, 1), strides=(1, 1, 1, 1), padding='VALID')  # 4x4 -> 1x1
    out = tf.nn.sigmoid(avgpool)
    return out
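With 256x256 inputs, the six stride-2 convolutions plus the 4x4 valid average pool reduce the discriminator to a single sigmoid score per image. A quick shape check (run from src/, assuming the repo's TF 1.x environment):

import tensorflow as tf
from model import discriminator

x = tf.placeholder(tf.float32, [8, 256, 256, 3])
print(discriminator(x).get_shape())  # (8, 1, 1, 1)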

src/test.py

+55
@@ -0,0 +1,55 @@
import time
import os
import re
import sys
import glob
import scipy.misc
from itertools import cycle
import cv2

import tensorflow as tf
import numpy as np

import utils
from utils import *
from model import *

font = cv2.FONT_HERSHEY_SIMPLEX
bottomLeftCornerOfText = (10, 250)
fontScale = 1
fontColor = (255, 255, 255)
lineType = 2


from skimage import measure

TESTING_SET_DIR = '../rephrase_data/degrade_test/'
GROUNDTRUTH_TEST_SET_DIR = '../rephrase_data/test/'
OUT_DIR = '../rephrase_data/output_small/'
BATCH_SIZE = 20  # note: the pools are split with utils.BATCH_SIZE, so actual batches may be smaller

def test():
    tf.reset_default_graph()

    global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')
    gen_in = tf.placeholder(shape=[None, BATCH_SHAPE[1], BATCH_SHAPE[2], BATCH_SHAPE[3]], dtype=tf.float32, name='generated_image')
    real_in = tf.placeholder(shape=[None, BATCH_SHAPE[1], BATCH_SHAPE[2], BATCH_SHAPE[3]], dtype=tf.float32, name='groundtruth_image')
    Gz = generator(gen_in)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        saver = initialize(sess)  # restores the latest checkpoint if one exists
        initial_step = global_step.eval()
        start_time = time.time()
        initial_step = 0
        total_iteration = 2
        for index in range(initial_step, total_iteration):
            # the pools and stats are created in utils' namespace by the init
            # calls below, after this module's star import has already run,
            # so read them through the module rather than as bare names
            test_batch = next(utils.test_pool)
            test_image = test_batch * utils.std + utils.mean  # undo normalization
            ground_test_batch = next(utils.ground_test_pool)
            out_image = sess.run(Gz, feed_dict={gen_in: test_batch})
            out_image = out_image * utils.std + utils.mean
            for j in range(len(test_batch)):  # iterate the actual batch size
                img_save = np.hstack((test_image[j], out_image[j], ground_test_batch[j]))
                cv2.putText(img_save, 'Distorted Generated GroundTruth', bottomLeftCornerOfText, font, fontScale, fontColor, lineType)
                cv2.imwrite(OUT_DIR + 'out_%d.png' % (len(test_batch) * index + j), img_save)

testing_dataset_init()
testing_truth_dataset_init()
test()

src/train.py

+79
@@ -0,0 +1,79 @@
import time

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from utils import *
from model import *

from skimage import measure

from tensorflow.python.client import device_lib


def train():
    tf.reset_default_graph()

    global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    gen_in = tf.placeholder(shape=[None, BATCH_SHAPE[1], BATCH_SHAPE[2], BATCH_SHAPE[3]], dtype=tf.float32, name='generated_image')
    real_in = tf.placeholder(shape=[None, BATCH_SHAPE[1], BATCH_SHAPE[2], BATCH_SHAPE[3]], dtype=tf.float32, name='groundtruth_image')
    Gz = generator(gen_in)
    Dx = discriminator(real_in)
    Dg = discriminator(Gz, reuse=True)
    lr_v = tf.Variable(lr_init)
    real_in_bgr = tf.map_fn(lambda img: RGB_TO_BGR(img), real_in)
    Gz_bgr = tf.map_fn(lambda img: RGB_TO_BGR(img), Gz)

    psnr = 0
    ssim = 0
    # note: the labels are inverted relative to the usual GAN convention --
    # d_loss drives Dx toward 0 on real images and Dg toward 1 on fakes,
    # and g_loss then pushes Dg back toward 0
    d_loss = -tf.reduce_mean(tf.log(1-Dx) + tf.log(Dg)) * ADVERSARIAL_LOSS_FACTOR
    g_loss = ADVERSARIAL_LOSS_FACTOR * -tf.reduce_mean(tf.log(1-Dg)) + PIXEL_LOSS_FACTOR * (get_pixel_loss(real_in, Gz))
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'd_' in var.name]
    g_vars = [var for var in t_vars if 'g_' in var.name]

    # the discriminator trains with 4x the generator's learning rate
    d_solver = tf.train.AdamOptimizer(4*lr_v).minimize(d_loss, var_list=d_vars, global_step=global_step)
    g_solver = tf.train.AdamOptimizer(lr_v).minimize(g_loss, var_list=g_vars)


    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)

        saver = initialize(sess)
        initial_step = global_step.eval()

        start_time = time.time()
        n_batches = 200
        total_iteration = n_batches * N_EPOCHS
        # print([n.attr['_output_shapes'] for n in tf.get_default_graph().as_graph_def(add_shapes=True).node])
        for index in range(initial_step, total_iteration):
            training_batch = load_next_training_batch()
            groundtruth_batch = load_next_groundtruth_batch()

            _, d_loss_cur = sess.run([d_solver, d_loss], feed_dict={gen_in: training_batch, real_in: groundtruth_batch})
            _, g_loss_cur = sess.run([g_solver, g_loss], feed_dict={gen_in: training_batch, real_in: groundtruth_batch})

            if (index + 1) % SKIP_STEP == 0:
                # CKPT_DIR ends in '/', so checkpoints come out as e.g. Checkpoints/-47999
                saver.save(sess, CKPT_DIR, index)
                image = sess.run(Gz, feed_dict={gen_in: training_batch})
                labels = sess.run(Dx, feed_dict={real_in: groundtruth_batch})
                labels_2 = sess.run(Dx, feed_dict={real_in: image})
                # new_lr_decay = lr_decay**(index // SKIP_STEP)
                # lr_v.assign(lr_init * new_lr_decay)
                img_save = (image[1]+1)/2
                cv2.imwrite(IMG_DIR + 'val_%d.png' % (index+1), img_save*255)
                print(np.sum(labels > 0.5), np.sum(labels_2 < 0.5))
                print(
                    "Step {}/{} Gen Loss: ".format(index + 1, total_iteration) + str(g_loss_cur) + " Disc Loss: " + str(
                        d_loss_cur))


if __name__ == '__main__':
    # print(device_lib.list_local_devices())
    training_dataset_init()
    groundtruth_dataset_init()
    train()
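For reference, the losses above invert the usual real/fake labels; the textbook non-saturating GAN objective would read as below (a sketch with hypothetical probability tensors, not this commit's formulation):

import tensorflow as tf

Dx = tf.placeholder(tf.float32, [None, 1, 1, 1])  # D(real), in (0, 1)
Dg = tf.placeholder(tf.float32, [None, 1, 1, 1])  # D(G(z)), in (0, 1)
d_loss_textbook = -tf.reduce_mean(tf.log(Dx) + tf.log(1 - Dg))
g_loss_textbook = -tf.reduce_mean(tf.log(Dg))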

src/utils.py

+126
@@ -0,0 +1,126 @@
import os
import re
import sys
import glob
import scipy.misc
from itertools import cycle
import cv2
import numpy as np
import tensorflow as tf

lr_init = 0.00001
lr_decay = 0.1
BATCH_SIZE = 8
BATCH_SHAPE = [BATCH_SIZE, 256, 256, 3]
SKIP_STEP = 200
N_EPOCHS = 1000
CKPT_DIR = './Checkpoints/'
IMG_DIR = './Images/'
GRAPH_DIR = './Graphs/'
TRAINING_SET_DIR = '../rephrase_data/degrade/'
GROUNDTRUTH_SET_DIR = '../rephrase_data/train/'
TESTING_SET_DIR = '../rephrase_data/degrade_test/'
GROUNDTRUTH_TEST_SET_DIR = '../rephrase_data/test/'
OUT_DIR = '../rephrase_data/output_small/'
ADVERSARIAL_LOSS_FACTOR = 64
PIXEL_LOSS_FACTOR = 0.1

def initialize(sess):
    saver = tf.train.Saver()
    writer = tf.summary.FileWriter(GRAPH_DIR, sess.graph)

    if not os.path.exists(CKPT_DIR):
        os.makedirs(CKPT_DIR)
    if not os.path.exists(IMG_DIR):
        os.makedirs(IMG_DIR)

    ckpt = tf.train.get_checkpoint_state(os.path.dirname(CKPT_DIR))
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    return saver

def load_next_training_batch():
    batch = next(pool)
    return batch

def training_dataset_init():
    filelist = glob.glob(TRAINING_SET_DIR + '/*/*.jpg')
    batch = np.array([cv2.resize(cv2.imread(fname), (256, 256)) for fname in filelist], dtype=float)
    print(np.mean(batch, axis=(1, 2, 3)).shape)
    # per-image standardization (testing uses a global mean/std instead)
    batch -= np.mean(batch, axis=(1, 2, 3), keepdims=True)
    batch /= batch.std(axis=(1, 2, 3), keepdims=True)
    batch = np.array(split(batch, BATCH_SIZE))
    global pool
    pool = cycle(batch)


def groundtruth_dataset_init():
    filelist = glob.glob(GROUNDTRUTH_SET_DIR + '/*/*.jpg')
    ground_batch = np.array([cv2.resize(cv2.imread(fname), (256, 256)) for fname in filelist], dtype=float)
    ground_batch -= np.mean(ground_batch, axis=(1, 2, 3), keepdims=True)
    ground_batch /= ground_batch.std(axis=(1, 2, 3), keepdims=True)
    ground_batch = np.array(split(ground_batch, BATCH_SIZE))
    print(ground_batch.shape)
    global ground_pool
    ground_pool = cycle(ground_batch)


def load_next_groundtruth_batch():
    batch = next(ground_pool)
    return batch

def testing_dataset_init():
    filelist = glob.glob(TESTING_SET_DIR + '/*/*.jpg')
    print(len(filelist))
    ground_batch = np.array([cv2.resize(cv2.imread(fname), (256, 256)) for fname in filelist], dtype=float)
    global mean, std
    # global statistics, kept so test.py can undo the normalization
    mean = np.mean(ground_batch, axis=(0, 1, 2, 3), keepdims=True)
    std = ground_batch.std(axis=(0, 1, 2, 3), keepdims=True)
    ground_batch -= mean
    ground_batch /= std
    print(mean, std)
    ground_batch = np.array(split(ground_batch, BATCH_SIZE))
    print(ground_batch.shape)
    global test_pool
    test_pool = cycle(ground_batch)

def testing_truth_dataset_init():
    filelist = glob.glob(GROUNDTRUTH_TEST_SET_DIR + '/*/*.jpg')
    ground_batch = np.array([cv2.resize(cv2.imread(fname), (256, 256)) for fname in filelist], dtype=float)
    ground_batch = np.array(split(ground_batch, BATCH_SIZE))
    print(ground_batch.shape)
    global ground_test_pool
    ground_test_pool = cycle(ground_batch)

def tryint(s):
    try:
        return int(s)
    except:
        return s

def alphanum_key(s):
    """ Turn a string into a list of string and number chunks.
        "z23a" -> ["z", 23, "a"]
    """
    return [tryint(c) for c in re.split('([0-9]+)', s)]


def split(arr, size):
    # chop arr into consecutive chunks of `size` (the last may be shorter)
    arrs = []
    while len(arr) > size:
        piece = arr[:size]
        arrs.append(piece)
        arr = arr[size:]
    arrs.append(arr)
    return arrs


def lrelu(x, leak=0.2, name='lrelu'):
    with tf.variable_scope(name):
        # leaky ReLU via abs(): f1*x + f2*|x| == max(x, leak*x)
        f1 = 0.5 * (1 + leak)
        f2 = 0.5 * (1 - leak)
        return f1 * x + f2 * abs(x)

def get_pixel_loss(target, prediction):
    pixel_difference = target - prediction
    pixel_loss = tf.nn.l2_loss(pixel_difference)
    return pixel_loss
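train.py maps RGB_TO_BGR over its batches, but no such helper appears anywhere in this commit. A minimal stand-in, under the assumption that it simply reverses the channel axis:

import tensorflow as tf

def RGB_TO_BGR(img):
    # assumed helper (not in the commit): swap RGB <-> BGR by reversing channels
    return tf.reverse(img, axis=[-1])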
