Search

21.10.18: pytorch-lightning + Hydra (1) Datamodule 작성하기

pytorch-lightning 에 맞도록 celebHQ-Mask-edge 데이터모듈 작성하기
1.
pytorch-lightning + hydra를 이용하면 configuration이 쉽고 여러 모델의 수정이 간편하므로 이 모델을 기준으로 바닥부터 작성해본다.
2.
migration 목표가 되는 코드는 다음과 같다.
3.
가장 먼저, datamodule부터 작성했다.
4.
src/datamodules/datasets/celeba_hq.py
기존의 코드를 기반으로, albumentation등 여러 효율적인 library를 적용하여 보다 세련되고 빠르게 (최적화) 만들어주었다.
이 데이터셋은 이미지 한 장에 대해 각 클래스별 마스크가 따로 1장씩 부여되어 총 15개의 리스트형태가 GT가 되는 특이한 형태를 띄고 있었다.
따라서 dataset.__get_item__()의 출력은 다음과 같다.
input_dict = {'label': label_tensor, 'image': image_tensor, 'path': image_path, 'self_ref': self_ref_flag, 'ref': ref_tensor, 'label_ref': label_ref_tensor }
YAML
복사
다른건 기계적으로 대입해주면 되는데, 주의할 점은 parser로 된 변수를 hydra config로 사용할 수 있도록 __init__(self, var1, var2..): 형식으로 만들어주어야 한다는 점이었다.
또 중요한 것은 augmentation이었는데, image를 augmentation하는 그 paramter를 그대로 label 15장, ref 1장, label_ref 15장씩 총 32장에 같은 augmentation을 적용해주어야 한다는 점이다.
이를 위해 Albumentation은 유용한 기능을 제공하고 있었다.
image = np.array(Image.open(image_path).convert('RGB')) labels = self.get_labels(label_path) #label 15장이 리스트형태로 한번에 들어간다. transform_data = self.transforms(image=image, masks=labels) image_tensor = transform_data['image'] label_tensor = transform_data['masks']
Python
복사
labels_ref = self.get_labels(label_path) transform_data_ref = A.ReplayCompose.replay(transform_data['replay'], image=image, masks=labels_ref) ref_tensor = transform_data_ref['image'] label_ref_tensor = transform_data_ref['masks']
Python
복사
transform_data 에 사용되었던 augmentation 정보를 그대로 불러올 수 있도록 해준다.
5.
src/datamodules/celeba_hq_datamodule.py
ReplayCompose를 사용하기 위해 A.ReplayCompose를 사용해준다.
transform_train = A.ReplayCompose([ A.Resize(width=self.load_size, height=self.load_size, always_apply=True), #, interpolation=inter_image), A.RandomCrop(width=self.crop_size, height=self.crop_size, always_apply=True), A.HorizontalFlip(p=0.5), # A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5), #, interpolation=inter_image), # A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5), ToTensorV2() ]) transform_test = A.ReplayCompose([ A.Resize(width=self.load_size, height=self.load_size), ToTensorV2() ]) self.data_train = CelebAHQEdgeDataset(transforms=transform_train, is_train=True, **dataset_dicts) self.data_test = CelebAHQEdgeDataset(transforms=transform_test, is_train=False, **dataset_dicts)
Python
복사
A.pytorch.transform.ToTensorV2() 가 작동하지 않는다. 이는 아래와 같이 직접 import해서 해결했다.
import albumentations as A from albumentations.pytorch.transforms import ToTensorV2
Python
복사
6.
config/datamodules/celeba_hq_datamodule.yaml
hydra 에 사용할 config 파일을 작성한다.
_target_: src.datamodules.celeba_hq_datamodule.CELEBAHQdatamodule data_dir: ${data_dir_celeba_hq} # data_dir is specified in config.yaml batch_size: 4 num_workers: 8 pin_memory: True load_size: 286 crop_size: 256
YAML
복사
7.
Test!
간단한 unit test를 작성해 제대로 불러오고 있는지 확인해보았다.
from argparse import Namespace import os from src.datamodules.celeba_hq_datamodule import CELEBAHQdatamodule import pretty_errors def test_mnist_datamodule(): celeba_hq_dicts={ 'data_dir': 'data/CelebAMask-HQ', 'batch_size': 4, 'num_workers': 8, 'pin_memory': True } datamodule = CELEBAHQdatamodule(**celeba_hq_dicts) datamodule.prepare_data() assert not datamodule.data_train and not datamodule.data_val and not datamodule.data_test assert os.path.exists(os.path.join("data", "CelebAMask-HQ")) assert os.path.exists(os.path.join("data", "CelebAMask-HQ", "CelebA-HQ-img")) assert os.path.exists(os.path.join("data", "CelebAMask-HQ", "CelebAMask-HQ-mask-anno")) datamodule.setup() # assert datamodule.data_train and datamodule.data_val and datamodule.data_test assert datamodule.data_train and datamodule.data_test print(f"len(datamodule.data_train): {len(datamodule.data_train)}") # = 24183 print(f"len(datamodule.data_test): {len(datamodule.data_test)}") # = 2993 assert datamodule.train_dataloader() assert datamodule.val_dataloader() assert datamodule.test_dataloader() batch = next(iter(datamodule.train_dataloader())) print(f''' batch['label'] | {type(batch['label'])} \t\t| len: {len(batch['label'])}, type: {type(batch['label'][0])}, shape: {batch['label_ref'][0].shape} batch['image'] | {type(batch['image'])} | shape: {batch['image'].shape} batch['path'] | {type(batch['path'])} \t\t| len: {len(batch['path'])}, type: {type(batch['path'][0])} batch['self_ref'] | {type(batch['self_ref'])} | shape: {batch['self_ref'].shape} batch['ref'] | {type(batch['ref'])} | shape: {batch['ref'].shape} batch['label_ref'] | {type(batch['label_ref'])} \t\t| len: {len(batch['label_ref'])}, type: {type(batch['label_ref'][0])}, shape: {batch['label_ref'][0].shape} ''') if __name__ == "__main__": test_mnist_datamodule()
Python
복사
batch['label'] | <class 'list'> | len: 15, type: <class 'torch.Tensor'>, shape: torch.Size([4, 256, 256]) batch['image'] | <class 'torch.Tensor'> | shape: torch.Size([4, 3, 256, 256]) batch['path'] | <class 'list'> | len: 4, type: <class 'str'> batch['self_ref'] | <class 'torch.Tensor'> | shape: torch.Size([4, 3, 256, 256]) batch['ref'] | <class 'torch.Tensor'> | shape: torch.Size([4, 3, 256, 256]) batch['label_ref'] | <class 'list'> | len: 15, type: <class 'torch.Tensor'>, shape: torch.Size([4, 256, 256])
Bash
복사
출력은 제대로 되는 것 같은데, 이미지로 만들어서 augmentation이 제대로 작동하는지 확인해 보아야 한다.
가장 중요한 뼈대는 완성했으니 이제 다른 Dataset (lsun_bedroom, ade20k, deepfashion)은 금방 할 수 있게 되었다.
내일은 본격적으로model을 만들어보자.