• Dataset: ADE20K (ADEChallengeData2016)
• Data Input Pipeline Visualization
pix2pixdataset.py [code]
import os
import random

import torch
from PIL import Image
# BaseDataset and get_transform come from the repo's data package

class Pix2pixDataset(BaseDataset):
    def __getitem__(self, index):
        # label image
        label_path = self.label_paths[index]
        label_tensor, params1 = self.get_label_tensor(label_path)

        # input image (real image)
        image_path = self.image_paths[index]  # e.g. 'datasets/ADEChallengeData2016/training/ADE_train_00000012.jpg'
        if not self.opt.no_pairing_check:
            assert self.paths_match(label_path, image_path), \
                "The label_path %s and image_path %s don't match." % \
                (label_path, image_path)
        image = Image.open(image_path)
        image = image.convert('RGB')
        transform_image = get_transform(self.opt, params1)
        image_tensor = transform_image(image)
image         '12.jpg' (ADE_train_00000012.jpg)
image_tensor  transformed input image tensor
label_tensor  transformed label tensor
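Both tensors are produced by transforms built from the same params1, which is what keeps random crops/flips identical across the label/image pair. A minimal sketch of that shared-params pattern (get_params/make_transform are illustrative stand-ins, not the repo's actual helpers):

# Sketch of the shared-params idea, not the repo's get_transform itself.
import random
import torchvision.transforms.functional as TF

def get_params():
    # draw every random augmentation decision once, up front
    return {'flip': random.random() > 0.5}

def make_transform(params):
    def apply(img):
        if params['flip']:
            img = TF.hflip(img)  # the same flip hits both label and image
        return TF.to_tensor(img)
    return apply

params = get_params()
label_transform = make_transform(params)  # reusing params means the label
image_transform = make_transform(params)  # and image stay pixel-aligned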
        ref_tensor = 0
        label_ref_tensor = 0

        random_p = random.random()
        if random_p < self.real_reference_probability or self.opt.phase == 'test':
            key = image_path.replace('\\', '/').split('DeepFashion/')[-1] if self.opt.dataset_mode == 'deepfashion' else os.path.basename(image_path)  # e.g. 'ADE_train_00000012.jpg'
            val = self.ref_dict[key]  # e.g. ['ADE_train_00000098.jpg', 'ADE_train_00002690.jpg']
            if random_p < self.hard_reference_probability:
                path_ref = val[1]  # hard reference
            else:
                path_ref = val[0]  # easy reference
key       'ADE_train_00000012.jpg'
easy ref  'ADE_train_00000098.jpg' (val[0])
hard ref  'ADE_train_00002690.jpg' (val[1])
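ref_dict maps every training image to a pre-computed [easy, hard] reference pair. A plausible sketch of loading such a mapping; the comma-separated line format and the file name are assumptions, not the repo's actual file layout:

# Hypothetical loader for ref_dict; "key,easy,hard" lines are assumed
# purely for illustration.
def load_ref_dict(path):
    ref_dict = {}
    with open(path) as f:
        for line in f:
            key, easy, hard = line.strip().split(',')
            ref_dict[key] = [easy, hard]  # val[0] = easy ref, val[1] = hard ref
    return ref_dict

ref_dict = load_ref_dict('ade20k_ref.txt')  # hypothetical file name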
            image_ref = Image.open(path_ref).convert('RGB')
            if self.opt.dataset_mode != 'deepfashion':
                path_ref_label = path_ref.replace('.jpg', '.png')
                path_ref_label = self.imgpath_to_labelpath(path_ref_label)
            else:
                path_ref_label = self.imgpath_to_labelpath(path_ref)

            label_ref_tensor, params = self.get_label_tensor(path_ref_label)
            transform_image = get_transform(self.opt, params)
            ref_tensor = transform_image(image_ref)
            # all-zeros flag: the reference is a different (real) image, not the input itself
            self_ref_flag = torch.zeros_like(ref_tensor)
        else:
            pair = False
            if self.opt.dataset_mode == 'deepfashion' and self.opt.video_like:
                key = image_path.replace('\\', '/').split('DeepFashion/')[-1]
                val = self.ref_dict[key]
                ref_name = val[0]
                key_name = key
                # pair only when the reference lives in the same folder and shares the filename prefix
                if os.path.dirname(ref_name) == os.path.dirname(key_name) and \
                        os.path.basename(ref_name).split('_')[0] == os.path.basename(key_name).split('_')[0]:
                    path_ref = os.path.join(self.opt.dataroot, ref_name)
                    image_ref = Image.open(path_ref).convert('RGB')
                    label_ref_path = self.imgpath_to_labelpath(path_ref)
                    label_ref_tensor, params = self.get_label_tensor(label_ref_path)
                    transform_image = get_transform(self.opt, params)
                    ref_tensor = transform_image(image_ref)
                    pair = True
            if not pair:
                # fall back to self-reference: the input image is its own exemplar
                label_ref_tensor, params = self.get_label_tensor(label_path)
                transform_image = get_transform(self.opt, params)
                ref_tensor = transform_image(image)
            # all-ones flag marks a self-reference sample
            self_ref_flag = torch.ones_like(ref_tensor)
ref_tensor        transformed reference image tensor
label_ref_tensor  transformed reference label tensor
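In the reference branch above, imgpath_to_labelpath must turn an RGB image path into its segmentation-mask path. A sketch under the standard ADEChallengeData2016 layout (images/*.jpg vs. annotations/*.png); the repo's actual helper may differ:

import os

def imgpath_to_labelpath(path):
    # Assumed mapping for the standard ADEChallengeData2016 layout:
    # .../images/training/ADE_train_x.jpg -> .../annotations/training/ADE_train_x.png.
    # This is an illustrative implementation, not the repo's code.
    path = path.replace('/images/', '/annotations/')
    return os.path.splitext(path)[0] + '.png'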
        input_dict = {'label': label_tensor,
                      'image': image_tensor,
                      'path': image_path,
                      'self_ref': self_ref_flag,
                      'ref': ref_tensor,
                      'label_ref': label_ref_tensor
                      }
input_dict
  label      [1, 256, 256]
  image      [3, 256, 256]
  path       'datasets/ADEChallengeData2016/training/ADE_train_00000012.jpg'
  self_ref   torch.ones_like(ref_tensor)  [3, 256, 256]
  ref        [3, 256, 256]
  label_ref  [1, 256, 256]
        # give subclasses a chance to modify the final output
        self.postprocess(input_dict)  # identity for this dataset
        return input_dict
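Each __getitem__ call therefore yields one paired sample. A sketch of how the dict is typically consumed through a DataLoader (constructing `dataset` and its options is omitted and assumed to follow the repo's conventions):

from torch.utils.data import DataLoader

# dataset is an already-initialized Pix2pixDataset (setup omitted here)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
batch = next(iter(loader))
for k in ('label', 'image', 'ref', 'label_ref', 'self_ref'):
    print(k, tuple(batch[k].shape))  # e.g. label -> (1, 1, 256, 256)
print(batch['path'])                 # batched list of source image paths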
pix2pixmodel.py [code]
Generator Step
import torch

class Pix2PixModel(torch.nn.Module):
    ...
    # Entry point for all calls involving a forward pass of the deep
    # networks. We use this approach because DataParallel cannot
    # parallelize custom functions, so we branch to different routines
    # based on |mode|.
    def forward(self, data, mode, GforD=None, alpha=1):
        (input_label, input_semantics, real_image, self_ref,
         ref_image, ref_label, ref_semantics) = self.preprocess_input(data)
real_image       [1, 3, 256, 256]
ref_image        [1, 3, 256, 256]
input_label      [1, 1, 256, 256]
input_semantics  [1, 151, 256, 256]  (one-hot)
ref_label        [1, 1, 256, 256]
ref_semantics    [1, 151, 256, 256]  (one-hot)
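The jump from the [1, 1, 256, 256] label maps to the [1, 151, 256, 256] semantics is a per-pixel one-hot encoding: ADE20K's 150 classes plus one "don't care" channel. A SPADE-style sketch of the scatter-based conversion that preprocess_input presumably performs (not the exact repo code):

import torch

def to_onehot(label_map, num_classes=151):
    # label_map: [N, 1, H, W] integer class indices
    n, _, h, w = label_map.size()
    onehot = torch.zeros(n, num_classes, h, w)
    # write a 1 into the channel named by each pixel's class index
    return onehot.scatter_(1, label_map.long(), 1.0)

label = torch.randint(0, 151, (1, 1, 256, 256))
semantics = to_onehot(label)                   # [1, 151, 256, 256]
assert bool(semantics.sum(dim=1).eq(1).all())  # exactly one hot channel per pixel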
        self.alpha = alpha
        generated_out = {}
        if mode == 'generator':
            g_loss, generated_out = self.compute_generator_loss(
                input_label, input_semantics, real_image,
                ref_label, ref_semantics, ref_image, self_ref)

            out = {}
            out['fake_image'] = generated_out['fake_image']
            out['input_semantics'] = input_semantics
            out['ref_semantics'] = ref_semantics
            # optional intermediate outputs; missing keys fall back to None
            out['warp_out'] = generated_out.get('warp_out', None)
            out['warp_mask'] = generated_out.get('warp_mask', None)
            out['adaptive_feature_seg'] = generated_out.get('adaptive_feature_seg', None)
            out['adaptive_feature_img'] = generated_out.get('adaptive_feature_img', None)
            out['warp_cycle'] = generated_out.get('warp_cycle', None)
            out['warp_i2r'] = generated_out.get('warp_i2r', None)
            out['warp_i2r2i'] = generated_out.get('warp_i2r2i', None)
            return g_loss, out
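This 'generator' branch is what the training loop invokes each iteration. A hypothetical driver sketch; that g_loss is a dict of named loss tensors and the optimizer setup are both assumptions:

# Hypothetical generator update step; assumes g_losses is a dict of named
# loss tensors and optimizer_G already wraps the generator parameters.
optimizer_G.zero_grad()
g_losses, out = model(data_i, mode='generator', alpha=1)
loss_G = sum(g_losses.values()).mean()
loss_G.backward()
optimizer_G.step()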