
21.11.01: Easy reference vs Hard Reference

Dataset: ADE20k (ADEChallengeData2016)
Data Input Pipeline Visualization

pix2pix_dataset.py [code]

class Pix2pixDataset(BaseDataset):
    def __getitem__(self, index):
        # Label image
        label_path = self.label_paths[index]
        label_tensor, params1 = self.get_label_tensor(label_path)

        # Input image (real image)
        image_path = self.image_paths[index]
        # image_path = 'datasets/ADEChallengeData2016/training/ADE_train_00000012.jpg'
        if not self.opt.no_pairing_check:
            assert self.paths_match(label_path, image_path), \
                "The label_path %s and image_path %s don't match." % \
                (label_path, image_path)
        image = Image.open(image_path)
        image = image.convert('RGB')
        transform_image = get_transform(self.opt, params1)
        image_tensor = transform_image(image)
image          the loaded input image ('ADE_train_00000012.jpg')
image_tensor   the transformed input tensor
label_tensor   the transformed segmentation-label tensor
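
Note that the random parameters drawn for the label transform (params1) are reused for the image, so random flips and crops stay pixel-aligned between the segmentation map and the photo. A minimal sketch of that shared-parameter pattern; the get_params/get_transform bodies below are illustrative stand-ins, not the repo's exact implementation:

import random

import torchvision.transforms as T
import torchvision.transforms.functional as TF

def get_params():
    # Draw every random decision once, so label and image share it
    return {'flip': random.random() < 0.5}

def get_transform(params, for_label=False):
    def apply(img):
        # Nearest-neighbor resize keeps integer class IDs intact; bicubic for photos
        interp = T.InterpolationMode.NEAREST if for_label else T.InterpolationMode.BICUBIC
        img = TF.resize(img, [256, 256], interpolation=interp)
        if params['flip']:
            img = TF.hflip(img)  # identical flip applied to label and image
        t = TF.to_tensor(img)
        if not for_label:
            t = TF.normalize(t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        return t
    return apply

__getitem__ then samples a reference image for the exemplar branch: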
        ref_tensor = 0
        label_ref_tensor = 0
        random_p = random.random()
        if random_p < self.real_reference_probability or self.opt.phase == 'test':
            key = image_path.replace('\\', '/').split('DeepFashion/')[-1] \
                if self.opt.dataset_mode == 'deepfashion' else os.path.basename(image_path)
            # key = 'ADE_train_00000012.jpg'
            val = self.ref_dict[key]
            # val = ['ADE_train_00000098.jpg', 'ADE_train_00002690.jpg']
            if random_p < self.hard_reference_probability:
                path_ref = val[1]  # hard reference
            else:
                path_ref = val[0]  # easy reference
key        'ADE_train_00000012.jpg'
easy ref   val[0] = 'ADE_train_00000098.jpg'
hard ref   val[1] = 'ADE_train_00002690.jpg'
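
In other words, a single draw of random_p decides both whether a real reference is used at all and, when it is, whether the hard one is picked (random_p below the smaller hard threshold) or the easy one. A standalone sketch of that two-threshold sampling, with illustrative probability values rather than the repo's defaults:

import random

def pick_reference(key, ref_dict,
                   real_reference_probability=0.7,
                   hard_reference_probability=0.2):
    # Returns (path, kind); kind is 'easy', 'hard', or 'self'.
    # (The repo additionally forces a real reference when opt.phase == 'test'.)
    random_p = random.random()
    if random_p < real_reference_probability:
        easy_ref, hard_ref = ref_dict[key]
        if random_p < hard_reference_probability:
            return hard_ref, 'hard'  # dissimilar exemplar: harder correspondence
        return easy_ref, 'easy'      # similar exemplar: easier correspondence
    return key, 'self'               # fall back to the image itself

path, kind = pick_reference(
    'ADE_train_00000012.jpg',
    {'ADE_train_00000012.jpg': ['ADE_train_00000098.jpg', 'ADE_train_00002690.jpg']})

Back in __getitem__, the chosen reference is then loaded and transformed: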
            image_ref = Image.open(path_ref).convert('RGB')

            if self.opt.dataset_mode != 'deepfashion':
                path_ref_label = path_ref.replace('.jpg', '.png')
                path_ref_label = self.imgpath_to_labelpath(path_ref_label)
            else:
                path_ref_label = self.imgpath_to_labelpath(path_ref)

            label_ref_tensor, params = self.get_label_tensor(path_ref_label)
            transform_image = get_transform(self.opt, params)
            ref_tensor = transform_image(image_ref)
            self_ref_flag = torch.zeros_like(ref_tensor)
        else:
            pair = False
            if self.opt.dataset_mode == 'deepfashion' and self.opt.video_like:
                key = image_path.replace('\\', '/').split('DeepFashion/')[-1]
                val = self.ref_dict[key]
                ref_name = val[0]
                key_name = key
                # Same folder and same person ID -> usable as a paired reference
                if os.path.dirname(ref_name) == os.path.dirname(key_name) and \
                        os.path.basename(ref_name).split('_')[0] == os.path.basename(key_name).split('_')[0]:
                    path_ref = os.path.join(self.opt.dataroot, ref_name)
                    image_ref = Image.open(path_ref).convert('RGB')
                    label_ref_path = self.imgpath_to_labelpath(path_ref)
                    label_ref_tensor, params = self.get_label_tensor(label_ref_path)
                    transform_image = get_transform(self.opt, params)
                    ref_tensor = transform_image(image_ref)
                    pair = True
            if not pair:
                # No real reference: the input image serves as its own reference
                label_ref_tensor, params = self.get_label_tensor(label_path)
                transform_image = get_transform(self.opt, params)
                ref_tensor = transform_image(image)
            self_ref_flag = torch.ones_like(ref_tensor)
ref_tensor         the transformed reference image
label_ref_tensor   the transformed reference label
        input_dict = {'label': label_tensor,
                      'image': image_tensor,
                      'path': image_path,
                      'self_ref': self_ref_flag,
                      'ref': ref_tensor,
                      'label_ref': label_ref_tensor,
                      }
input_dict
    label       [1, 256, 256]
    image       [3, 256, 256]
    path        'datasets/ADEChallengeData2016/training/ADE_train_00000012.jpg'
    self_ref    torch.ones_like(ref_tensor), [3, 256, 256]
    ref         [3, 256, 256]
    label_ref   [1, 256, 256]
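
To sanity-check these shapes, it is easy to wrap the dataset in a plain DataLoader and print every entry of one batch; a minimal sketch, assuming a dataset instance of Pix2pixDataset has already been constructed:

import torch
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=1, shuffle=False)  # 'dataset' assumed built
batch = next(iter(loader))
for k, v in batch.items():
    if torch.is_tensor(v):
        print(f'{k:10s} {list(v.shape)}')  # e.g. label [1, 1, 256, 256]
    else:
        print(f'{k:10s} {v}')              # 'path' collates to a list of strings

Finally, __getitem__ gives subclasses a hook and returns the dict: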
        # Give subclasses a chance to modify the final output
        self.postprocess(input_dict)  # identity (no-op) by default
        return input_dict

pix2pix_model.py [code]

Generator Step

class Pix2PixModel(torch.nn.Module):
    ...
    # Entry point for all calls involving a forward pass of the deep
    # networks. This approach is used because the DataParallel module
    # can't parallelize custom functions; we branch to different
    # routines based on |mode|.
    def forward(self, data, mode, GforD=None, alpha=1):
        input_label, input_semantics, real_image, self_ref, ref_image, \
            ref_label, ref_semantics = self.preprocess_input(data)
real_image    [1, 3, 256, 256]
ref_image     [1, 3, 256, 256]
input_label   [1, 1, 256, 256]    input_semantics   [1, 151, 256, 256] (one-hot)
ref_label     [1, 1, 256, 256]    ref_semantics     [1, 151, 256, 256] (one-hot)
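
preprocess_input expands the integer label map into a 151-channel one-hot tensor (150 ADE20k classes plus one extra 'unknown' channel). A minimal sketch of that conversion using scatter_, the usual SPADE-style idiom, though not necessarily the repo's exact code:

import torch

def to_semantics(label_map, num_classes=151):
    # label_map: [N, 1, H, W] integer class IDs
    n, _, h, w = label_map.size()
    semantics = torch.zeros(n, num_classes, h, w)
    semantics.scatter_(1, label_map.long(), 1.0)  # put a 1 at each pixel's class channel
    return semantics

label = torch.randint(0, 151, (1, 1, 256, 256))
print(to_semantics(label).shape)  # torch.Size([1, 151, 256, 256])

Continuing in forward():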
        self.alpha = alpha
        generated_out = {}
        if mode == 'generator':
            g_loss, generated_out = self.compute_generator_loss(
                input_label, input_semantics, real_image,
                ref_label, ref_semantics, ref_image, self_ref)

            out = {}
            out['fake_image'] = generated_out['fake_image']
            out['input_semantics'] = input_semantics
            out['ref_semantics'] = ref_semantics
            # Optional intermediate outputs; None when the generator didn't produce them
            for key in ('warp_out', 'warp_mask', 'adaptive_feature_seg',
                        'adaptive_feature_img', 'warp_cycle',
                        'warp_i2r', 'warp_i2r2i'):
                out[key] = generated_out.get(key)
            return g_loss, out
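During training, the trainer calls the model with mode='generator', sums the returned loss dict, and steps the generator optimizer. A hedged sketch of that call pattern; model, optimizer_G, and the data batch are assumed to exist, and the loss-summing line follows the SPADE-style trainer convention:

# One generator update
g_losses, out = model(data, mode='generator')
loss_G = sum(g_losses.values()).mean()  # g_losses: dict of individual loss terms
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
fake = out['fake_image']                # [1, 3, 256, 256] synthesized result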