[Pytorch] requires_grad=False로 해도 학습이 되는 경우
Updated:
연구를 하던 중에 어떤 파라미터의 업데이트를 중단하기 위해 requires_grad를 False로 설정해도 계속해서 weight가 업데이트 되는 이슈가 있었다.
예를 들어 파라미터 $A$와 파라미터 $B$가 있을 때
전체 $epoch$의 처음 절반은 $A$를 학습하고 이후 절반은 $B$를 학습하고자 할 때
초기 설정을 A.requires_grad = True, B.requires_grad = False로 설정하고,
중간에 다시 A.requires_grad = False, B.requires_grad = True로 설정하였다.
하지만 A가 이후에도 계속해서 업데이트 되는 문제가 발생하였다.
디버깅을 통해 A의 gradient를 출력해보니 None이 아닌 0.0이 출력되었다.
참고링크 를 통해 원인에 대해 명확하게 이해하진 못했지만 해결방안은 찾을 수 있었다.
A.requires_grad = False를 할 때 동시에 A.grad = None으로 설정해주면 해결이 되었다.
why????
생각나는 원인은
- weight_decay에 의한 업데이트
- requires_grad = False로 두어 gradient 추척은 되지 않지만 이전의 gradient가 zero_grad()를 통해 0.0으로 되어 1번과 같은 다른 원인들에 의해 업데이트
정도가 있는 것 같다.
Leave a comment