Search

21.10.20: pytorch-lightning + Hydra (2) argparse → hydra!

hydra를 사용하는 이유는, 여러가지 configuration을 쉽고 빠르게 관리하고 참조할 수 있기 때문이다.
나는 코드 템플릿으로 아래 repo를 사용했다.
이를 사용하면 코드 구조가 다음과 같이 된다.
yaml 형태의 configuration을 모아놓은 configs
코드를 모아놓은 src
이렇게 하면 코드를 한번만 작성해도 yaml로 여러가지 세팅을 지정하여 쉽게 실험을 수행할 수 있다.
더이상 argparse를 조작하는 일은 하지 않아도 되고, 이에 따르는 부수적인 util 코드들도 필요 없어진다.

1. Baseline model의 configuration 을 yaml 화 해주기

먼저 베이스라인이 되는 코드에서 configuration들을 옮겨 올 것이다.
코드가 매우 간단해지고, 이를 참조하는 것도 매우 간단하다.
예를 들어 아래와 같은 yaml을 작성했다고 해보자.
./config/model/my_model.yaml
_target_: src.models.my_model.MyLitModel opt_G: param1: 42 param2: True param3: abc opt_D: param1: 42 param2: False param3: abc
YAML
복사
이 yaml을 이용하여 해당 _target_ 을 instantiate 해주면 다음과 같이 참조할 수 있다.
./src/models/my_model.py
class MyLitModel(LightningModule): def __init__(self, opt_G, opt_D): super().__init__() self.save_hyperparameters() print(opt_G) # Dict 형태로 출력된다. print(opt_G.param1) # Attribute style로 변수에 접근 가능하다. netG = Generator(opt_G) # 클래스에 Dict를 통째로 넣는 것도 된다. netD = Discriminator(param1=opt_D.param1) #물론 이것도 가능하다.
Python
복사
Output
# print(opt_G) {'param1':42, 'param2':True, 'param3': 'abc'} # print(opt_G.param1) 42
Python
복사
매우 간단하다!

2. model 에서 datamodule.yaml 참조하기

1.
datamodule의 __init__()에 아래와 같은 스니펫을 넣어준다.
OmegaConf.register_new_resolver( "datamodule", #resolver name lambda name: getattr(self, name), use_cache=False )
Python
복사
2.
model.yaml 에서 참조하길 원하는 attribute를 지정한다.
# this will get 'datamodule.some_param' field some_parameter: ${datamodule: some_param}
Python
복사

3. 기존 네트워크 코드 불러오기

1.
기존 코드의 뼈대가 되는 부분은 src/models/modules/{network_name}/ 폴더에 넣어준다.
2.
우리는 기존 코드의 trainer.py에 해당하는 코드를 pytorch-lightning으로 바꾸어 줄 것이다. 기존 코드는 cocosnet/ 폴더 아래에 있고, 우리가 refactor하는 코드는 cocosnet_model.py가 된다.
3.
일단 네트워크를 정의하고 불러와봤더니
성공! 네트워크가 잘 정의되었다.
기존의 argparse 가 복잡하게 얽혀있던 코드를 hydra를 사용해서 간단하게 해보았다.
내일은 pytorch-lightning을 이용하여 training scheme을 간단하게 만들어볼 것이다.