논문 베이스라인으로 mipNeRF를 쓰게 되었다.
https://jonbarron.info/mipnerf/
mipNeRF는 기존 NeRF 의 per-pixel ray sampling 을 conical frustrum 을 이용한 sampling 으로 바꿈으로서 multi-scale representation 은 물론 훨씬 가벼운 무게와 빠른 속도까지 가능하게 한 논문이다. 다만 JAX를 썼다는 치명적인 단점이 있지만..
환경 설정하는 과정에서 어려움을 많이 겪었고, 주변에서도 mipNeRF 돌리려다 고생했단 이야기가 건너건너 들리지만, git issue 는 물론 stackoverflow 등등에도 참고할 내용들을 찾을 수 없었다. 그래서 나중에 스스로 참고할 겸, 혹시라도 자료가 없어 나처럼 고생할 사람들을 위해 굉장히 rough한 선에서 정리해 보았다.
(거의 notion 에서 긁어오다시피 한 수준이지만 양해를 부탁드립니다)
Window Setting
>> 결과부터 얘기하자면 실패. 서버에서 실험을 돌리는 동안 디버깅할 생각으로 세팅을 시도했지만, 애초에 build from source 따위는 손대는 게 아니었다. 고생 좀 하다가 우분투로.
Ubuntu Setting
- requirements txt 로 install
- pip intall flax - jax, jaxlib 등 package 다운로드 중 에러 발생
- pip install gin-config 수동 설치
- pip intsall opencv-python 수동 설치 - ≥ 4.4.0 이라지만 대충 깔아도 더 높은 버전 깔 것 같아서 그냥 함
- pip install numpy==1.18.0 적당히 만족하는 버전으로 수동 설치
환경설정 완료 후 디버깅
- flax.nn에 없는 relu , sigmoid 등의 function 을 config 하려함 - 뒤져보니 flax.nn 이 아니라 jax.nn에 있다.
* 후에 알게 된 사실이지만, flax.nn 에 있던 function 들은 그보다 최근 버전에서 flax.linen으로 넘어갔다
- 데이터셋 path 에러 - utils.py 에서 data_dir 및 기타 flag 설정
- GPU support 버전으로 reinstall 하는데, 현재 나와있는 jax, jaxlib 의 latest version 은 cuda support 가 없다. 그리고 현재 서버에 깔린 버전은 0.3.7.
- [서버에 깔린 쿠다 버전 - 10.2 호환되는 gpu jax 는 0.1.71 이 최대. 맞는 버전으로 downgrade
pip install --upgrade jax jaxlib==0.1.71+cuda102 -f <https://storage.googleapis.com/jax-releases/jax_releases.html>
- 이번엔 코드 돌리려면 jax 0.3.2 이상이 필요하다고 함 ㅎ
- jax official site 에 jax gpu 버전은 cuda 11.1 이상부터 support 가 되고, 그보다 구 버전인 경우 build from source 로 알아서 깔아야 된다고 나옴...
- 현 server 는 driver version 이 440 대라 cuda 11.1 대를 깔 수 없고, 내가 맘대로 cuda version 을 바꿀 수 있는 상황도 아니어서 결국 다른 cuda 11.2 로 세팅된 server 로 옮기기로 결정
- CUDA 11.2 cudnn 8.1.1 로 버전 확인
- jaxlib version 은 0.1.65 인가가 맞고, 이 버전에 compatible 한 jax 가 깔려있지 않아서 그런 것 - jax 를 requirements.txt 에서요구하는 최소 version - 0.2.12 로 downgrade 시켜서 해결
jax.nn.initializers.constant?
jax.nn.initializer에 constant funciton을 추가하면 좋다는 github issue 가 있었던 듯
https://github.com/google/jax/issues/7242
해당 코드를 error가 발생하는 src/nn/initializers.py에 추가 하고, jax/nn/initializer 에서도 다르 function 들과 함께 import 하도록
- train_llff.sh 실행시 configuration file 이 제대로 안 들어가는 것 확인 >> config 의 flag 들 수정
- utils.py 의 gin.add_config_file_search_path 도 수정 - 결과적으로 utils 의 common flags 수정으로 들어간 듯.
- 결과적으로 configuration 먹이는데는 성공 - 이번에는 mip-nerf construction 에서 터짐
- jax downgrade - 0.2.26. flax 도 0.4.1로 깔려있었는데 jax 0.3 이상이어야 된다해서 0.3.1 바꿈. 아마도 flax.nn 다시 먹을 듯.
- jax 여전히 jax lib 0.1.65 랑 호환안됨. 더 다운그레이드
- 결국 0.2.12 아니면 호환안됨 ㅋㅋㅋ
- numpy 호환 문제로 1.18.0로 재설치
JAX / Jaxlib / flax / numpy compatibility 최종
- jax: 0.2.12
- jaxlib: 0.1.65+cuda112 (gpu, tpu 호환은 cuda 11.x 밑으로는 안됨)
- numpy 1.18.0 ( 그 밑으로 가면 다른 라이브러리랑 안 맞음)
- flax 0.3.1
결론
>> 안 돌아갈 때 괜히 에러 나는 부분 이거저거 고치지 말고 그냥 library version 만 맞추자.. flax, jax, jaxlib release note 다 켜놓고 하나씩 version 뒤로 돌리다가 성공했다.