본문 바로가기

Computer Graphics

mipNeRF / Google JAX 개발 환경 Setting

반응형

논문 베이스라인으로 mipNeRF를 쓰게 되었다. 

https://jonbarron.info/mipnerf/

 

mip-NeRF

Project page for Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields.

jonbarron.info

mipNeRF는 기존 NeRF 의 per-pixel ray sampling 을 conical frustrum 을 이용한 sampling 으로 바꿈으로서 multi-scale representation 은 물론 훨씬 가벼운 무게와 빠른 속도까지 가능하게 한 논문이다. 다만 JAX를 썼다는 치명적인 단점이 있지만..

환경 설정하는 과정에서 어려움을 많이 겪었고, 주변에서도 mipNeRF 돌리려다 고생했단 이야기가 건너건너 들리지만, git issue 는 물론 stackoverflow 등등에도 참고할 내용들을 찾을 수 없었다. 그래서 나중에 스스로 참고할 겸, 혹시라도 자료가 없어 나처럼 고생할 사람들을 위해 굉장히 rough한 선에서 정리해 보았다.

(거의 notion 에서 긁어오다시피 한 수준이지만 양해를 부탁드립니다)

Window Setting

>> 결과부터 얘기하자면 실패. 서버에서 실험을 돌리는 동안 디버깅할 생각으로 세팅을 시도했지만, 애초에 build from source 따위는 손대는 게 아니었다. 고생 좀 하다가 우분투로.

Ubuntu Setting

  • requirements txt 로 install
  • pip intall flax - jax, jaxlib 등 package 다운로드 중 에러 발생

  • pip install gin-config 수동 설치
  • pip intsall opencv-python 수동 설치 - ≥ 4.4.0 이라지만 대충 깔아도 더 높은 버전 깔 것 같아서 그냥 함
  • pip install numpy==1.18.0 적당히 만족하는 버전으로 수동 설치

환경설정 완료 후 디버깅

  • flax.nn에 없는 relu , sigmoid 등의 function 을 config 하려함 - 뒤져보니 flax.nn 이 아니라 jax.nn에 있다. 

* 후에 알게 된 사실이지만, flax.nn 에 있던 function 들은 그보다 최근 버전에서 flax.linen으로 넘어갔다

  • 데이터셋 path 에러 - utils.py 에서 data_dir 및 기타 flag 설정
  • GPU support 버전으로 reinstall 하는데, 현재 나와있는 jax, jaxlib 의 latest version 은 cuda support 가 없다. 그리고 현재 서버에 깔린 버전은 0.3.7.
  • [서버에 깔린 쿠다 버전 - 10.2 호환되는 gpu jax 는 0.1.71 이 최대. 맞는 버전으로 downgrade
pip install --upgrade jax jaxlib==0.1.71+cuda102 -f <https://storage.googleapis.com/jax-releases/jax_releases.html>
  • 이번엔 코드 돌리려면 jax 0.3.2 이상이 필요하다고 함 ㅎ
  • jax official site 에 jax gpu 버전은 cuda 11.1 이상부터 support 가 되고, 그보다 구 버전인 경우 build from source 로 알아서 깔아야 된다고 나옴...
  • 현 server 는 driver version 이 440 대라 cuda 11.1 대를 깔 수 없고, 내가 맘대로 cuda version 을 바꿀 수 있는 상황도 아니어서 결국 다른 cuda 11.2 로 세팅된  server 로 옮기기로 결정
  • CUDA 11.2 cudnn 8.1.1 로 버전 확인
  • jaxlib version 은 0.1.65 인가가 맞고, 이 버전에 compatible 한 jax 가 깔려있지 않아서 그런 것 - jax 를 requirements.txt 에서요구하는 최소 version - 0.2.12 로 downgrade 시켜서 해결

jax.nn.initializers.constant?

jax.nn.initializer에 constant funciton을 추가하면 좋다는 github issue 가 있었던 듯

https://github.com/google/jax/issues/7242

 

jax.nn.initializers.constant implementation · Issue #7242 · google/jax

Please: Check for duplicate requests. Describe your goal, and if possible provide a code snippet with a motivating example. There doesn't seem to be a jax.nn.initializers.constant function. I w...

github.com

해당 코드를 error가 발생하는 src/nn/initializers.py에 추가 하고, jax/nn/initializer 에서도 다르 function 들과 함께 import 하도록

  • train_llff.sh 실행시 configuration file 이 제대로 안 들어가는 것 확인 >> config 의 flag 들 수정
  • utils.py 의 gin.add_config_file_search_path 도 수정 - 결과적으로 utils 의 common flags 수정으로 들어간 듯.
  • 결과적으로 configuration 먹이는데는 성공 - 이번에는 mip-nerf construction 에서 터짐

  • jax downgrade - 0.2.26. flax 도 0.4.1로 깔려있었는데 jax 0.3 이상이어야 된다해서 0.3.1 바꿈. 아마도 flax.nn 다시 먹을 듯.
  • jax 여전히 jax lib 0.1.65 랑 호환안됨. 더 다운그레이드
  • 결국 0.2.12 아니면 호환안됨 ㅋㅋㅋ
  • numpy 호환 문제로 1.18.0로 재설치

JAX / Jaxlib / flax / numpy compatibility 최종

  • jax: 0.2.12
  • jaxlib: 0.1.65+cuda112 (gpu, tpu 호환은 cuda 11.x 밑으로는 안됨)
  • numpy 1.18.0 ( 그 밑으로 가면 다른 라이브러리랑 안 맞음)
  • flax 0.3.1 

결론

>> 안 돌아갈 때 괜히 에러 나는 부분 이거저거 고치지 말고 그냥 library version 만 맞추자.. flax, jax, jaxlib release note 다 켜놓고 하나씩 version 뒤로 돌리다가 성공했다. 

반응형