pytorch-lightning 에 맞도록 celebHQ-Mask-edge 데이터모듈 작성하기
1.
pytorch-lightning + hydra를 이용하면 configuration이 쉽고 여러 모델의 수정이 간편하므로 이 모델을 기준으로 바닥부터 작성해본다.
2.
migration 목표가 되는 코드는 다음과 같다.
3.
가장 먼저, datamodule부터 작성했다.
4.
src/datamodules/datasets/celeba_hq.py
•
기존의 코드를 기반으로, albumentation등 여러 효율적인 library를 적용하여 보다 세련되고 빠르게 (최적화) 만들어주었다.
•
이 데이터셋은 이미지 한 장에 대해 각 클래스별 마스크가 따로 1장씩 부여되어 총 15개의 리스트형태가 GT가 되는 특이한 형태를 띄고 있었다.
•
따라서 dataset.__get_item__()의 출력은 다음과 같다.
input_dict = {'label': label_tensor,
'image': image_tensor,
'path': image_path,
'self_ref': self_ref_flag,
'ref': ref_tensor,
'label_ref': label_ref_tensor
}
YAML
복사
•
다른건 기계적으로 대입해주면 되는데, 주의할 점은 parser로 된 변수를 hydra config로 사용할 수 있도록 __init__(self, var1, var2..): 형식으로 만들어주어야 한다는 점이었다.
•
또 중요한 것은 augmentation이었는데, image를 augmentation하는 그 paramter를 그대로 label 15장, ref 1장, label_ref 15장씩 총 32장에 같은 augmentation을 적용해주어야 한다는 점이다.
이를 위해 Albumentation은 유용한 기능을 제공하고 있었다.
image = np.array(Image.open(image_path).convert('RGB'))
labels = self.get_labels(label_path) #label 15장이 리스트형태로 한번에 들어간다.
transform_data = self.transforms(image=image, masks=labels)
image_tensor = transform_data['image']
label_tensor = transform_data['masks']
Python
복사
labels_ref = self.get_labels(label_path)
transform_data_ref = A.ReplayCompose.replay(transform_data['replay'], image=image, masks=labels_ref)
ref_tensor = transform_data_ref['image']
label_ref_tensor = transform_data_ref['masks']
Python
복사
transform_data 에 사용되었던 augmentation 정보를 그대로 불러올 수 있도록 해준다.
5.
src/datamodules/celeba_hq_datamodule.py
•
ReplayCompose를 사용하기 위해 A.ReplayCompose를 사용해준다.
transform_train = A.ReplayCompose([
A.Resize(width=self.load_size, height=self.load_size, always_apply=True), #, interpolation=inter_image),
A.RandomCrop(width=self.crop_size, height=self.crop_size, always_apply=True),
A.HorizontalFlip(p=0.5),
# A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5), #, interpolation=inter_image),
# A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
ToTensorV2()
])
transform_test = A.ReplayCompose([
A.Resize(width=self.load_size, height=self.load_size),
ToTensorV2()
])
self.data_train = CelebAHQEdgeDataset(transforms=transform_train, is_train=True, **dataset_dicts)
self.data_test = CelebAHQEdgeDataset(transforms=transform_test, is_train=False, **dataset_dicts)
Python
복사
•
A.pytorch.transform.ToTensorV2() 가 작동하지 않는다.
이는 아래와 같이 직접 import해서 해결했다.
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
Python
복사
6.
config/datamodules/celeba_hq_datamodule.yaml
•
hydra 에 사용할 config 파일을 작성한다.
_target_: src.datamodules.celeba_hq_datamodule.CELEBAHQdatamodule
data_dir: ${data_dir_celeba_hq} # data_dir is specified in config.yaml
batch_size: 4
num_workers: 8
pin_memory: True
load_size: 286
crop_size: 256
YAML
복사
7.
Test!
간단한 unit test를 작성해 제대로 불러오고 있는지 확인해보았다.
from argparse import Namespace
import os
from src.datamodules.celeba_hq_datamodule import CELEBAHQdatamodule
import pretty_errors
def test_mnist_datamodule():
celeba_hq_dicts={
'data_dir': 'data/CelebAMask-HQ',
'batch_size': 4,
'num_workers': 8,
'pin_memory': True
}
datamodule = CELEBAHQdatamodule(**celeba_hq_dicts)
datamodule.prepare_data()
assert not datamodule.data_train and not datamodule.data_val and not datamodule.data_test
assert os.path.exists(os.path.join("data", "CelebAMask-HQ"))
assert os.path.exists(os.path.join("data", "CelebAMask-HQ", "CelebA-HQ-img"))
assert os.path.exists(os.path.join("data", "CelebAMask-HQ", "CelebAMask-HQ-mask-anno"))
datamodule.setup()
# assert datamodule.data_train and datamodule.data_val and datamodule.data_test
assert datamodule.data_train and datamodule.data_test
print(f"len(datamodule.data_train): {len(datamodule.data_train)}") # = 24183
print(f"len(datamodule.data_test): {len(datamodule.data_test)}") # = 2993
assert datamodule.train_dataloader()
assert datamodule.val_dataloader()
assert datamodule.test_dataloader()
batch = next(iter(datamodule.train_dataloader()))
print(f'''
batch['label'] | {type(batch['label'])} \t\t| len: {len(batch['label'])}, type: {type(batch['label'][0])}, shape: {batch['label_ref'][0].shape}
batch['image'] | {type(batch['image'])} | shape: {batch['image'].shape}
batch['path'] | {type(batch['path'])} \t\t| len: {len(batch['path'])}, type: {type(batch['path'][0])}
batch['self_ref'] | {type(batch['self_ref'])} | shape: {batch['self_ref'].shape}
batch['ref'] | {type(batch['ref'])} | shape: {batch['ref'].shape}
batch['label_ref'] | {type(batch['label_ref'])} \t\t| len: {len(batch['label_ref'])}, type: {type(batch['label_ref'][0])}, shape: {batch['label_ref'][0].shape}
''')
if __name__ == "__main__":
test_mnist_datamodule()
Python
복사
batch['label'] | <class 'list'> | len: 15, type: <class 'torch.Tensor'>, shape: torch.Size([4, 256, 256])
batch['image'] | <class 'torch.Tensor'> | shape: torch.Size([4, 3, 256, 256])
batch['path'] | <class 'list'> | len: 4, type: <class 'str'>
batch['self_ref'] | <class 'torch.Tensor'> | shape: torch.Size([4, 3, 256, 256])
batch['ref'] | <class 'torch.Tensor'> | shape: torch.Size([4, 3, 256, 256])
batch['label_ref'] | <class 'list'> | len: 15, type: <class 'torch.Tensor'>, shape: torch.Size([4, 256, 256])
Bash
복사
•
출력은 제대로 되는 것 같은데, 이미지로 만들어서 augmentation이 제대로 작동하는지 확인해 보아야 한다.
•
가장 중요한 뼈대는 완성했으니 이제 다른 Dataset (lsun_bedroom, ade20k, deepfashion)은 금방 할 수 있게 되었다.
내일은 본격적으로model을 만들어보자.