Back propgation에서의 전치행렬(transpose matrix) – 1편

실제로 전개해보면 다음 식이 도출됩니다([식 5.13]으로 이끄는 과정은 생략합니다).

– p172, 5.6.1 Affine 계층, 밑바닥부터 시작하는 딥러닝

아니! 그걸 생략하면 어떡해요!!

“밑바닥부터 시작하는 딥러닝”을 읽으면서 딥러닝의 개념을 잡는데 많은 도움을 받고 있지만 굳이 단점을 들자면 주요한 공식 들에 대해 설명하지 않고 그냥 넘어 가버리는 경우가 가끔 있다. 위에서 말하는 [식 5.13]은 back propagataion에서 입력에 대한 loss function의 영향과 weight에 대한 loss function의 영향을 계산하는 다음 식을 의미한다.

\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} \cdot W^T \\\\ \frac{\partial L}{\partial W} = X^Y \cdot \frac{\partial L}{\partial Y}

이 식이 도대체 어떻게 유도된 것인지 이리 저리 찾다가 마침 이 부분을 자세히 설명해 주고 있는 미국 어느 대학(!)의 훌륭한 문서(Backpropagation for a Linear Layer, Justin Johnson, April 19, 2017)를 발견했다. 이 포스팅은 해당 문서에 대한 나름의 이해를 정리한 것이다.

밑밥 깔기

Matrix인 입력 X, Weight W가 있다고 할 때, 이 둘의 dot product인 Y는 다음과 같은 모습이다.

X = \begin{bmatrix}x_{1,1} & x_{1,2}\\x_{2,1} & x_{2,2}\end{bmatrix} W = \begin{bmatrix}w_{1,1} & w_{1,2} & w_{1,3}\\w_{2,1} & w_{2,2} & w_{2,3}\end{bmatrix} Y = X \cdot W = \begin{bmatrix}x_{1,1}w_{1,1} + x_{1,2}w_{2,1} & x_{1,1}w_{1,2} + x_{1,2}w_{2,2} & x_{1,1} w_{1,3} + x_{1,2}w_{2,3} \\\\ x_{2,1}w_{1,1} + x_{2,2}w_{2,1} & x_{2,1}w_{1,2} + x_{2,2}w_{2,2} & x_{2,1}w_{1,3} + x_{2,2}w_{2,3}\end{bmatrix}

Back propagation을 통해 최종으로 구하고자 하는 것은 입력의 변화에 따른 loss function의 변화량 \frac{\partial L}{\partial X}과 Weight 변화에 따른 loss function의 변화량 \frac{\partial L}{\partial W}이다. 이것과 관련해 연쇄 법칙(chain rule)에 따라 이전 layer에서 전달 받은 Y = X \cdot W의 변화에 따른 loss function의 변화량인 \frac{\partial L}{\partial Y}를 고려하면 다음이 성립한다.

\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} \cdot \frac{\partial Y}{\partial X} \\\\ \frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y} \cdot \frac{\partial Y}{\partial W}

Y의 변화에 따른 Loss function의 변화 \frac{\partial L}{\partial Y}

여기에서 \partial L은 scalar 값 이고 Y는 matrix이므로 \frac{\partial L}{\partial Y}의 모습은 다음과 같다.

\begin{bmatrix}\frac{\partial L}{\partial (x_{1,1}w_{1,1} + x_{1,2}w_{2,1})} & \frac {\partial L}{\partial (x_{1,1}w_{1,2} + x_{1,2}w_{2,2})} & \frac {\partial L}{\partial (x_{1,1} w_{1,3} + x_{1,2}w_{2,3})} \\\\ \frac{\partial L}{\partial (x_{2,1}w_{1,1} + x_{2,2}w_{2,1})} & \frac{\partial L}{\partial (x_{2,1}w_{1,2} + x_{2,2}w_{2,2})} & \frac{\partial L}{\partial (x_{2,1}w_{1,3} + x_{2,2}w_{2,3})}\end{bmatrix}

복잡하니까 조금 간단히 다음과 같이 인덱스로 나타내자.

\frac{\partial L}{\partial Y} = \begin{bmatrix}\frac{\partial L}{\partial y_{1,1}} & \frac {\partial L}{\partial y_{1,2}} & \frac {\partial L}{\partial y_{1,3}} \\\\ \frac{\partial L}{\partial y_{2,1}} & \frac{\partial L}{\partial y_{2,2}} & \frac{\partial L}{\partial y_{2,3}}\end{bmatrix}

이제, \frac{\partial Y}{\partial X}\frac{\partial Y}{\partial W}가 남았다.

행렬 X의 원소들에 대한 행렬 Y의 편미분 \frac{\partial Y}{\partial X}

먼저 \frac{\partial Y}{\partial X}를 보면, X와 Y 모두 matrix이니 \frac {\partial Y}{\partial X}는 다음과 같이 생겼다.

\frac{\partial Y}{\partial X} = \begin{bmatrix}\frac{\partial Y}{\partial x_{1,1}} & \frac{\partial Y}{\partial x_{1,2}} & \frac{\partial Y}{\partial x_{1,3}} \\\\ \frac{\partial Y}{\partial x_{2,1}} & \frac{\partial Y}{\partial x_{2,2}} & \frac{\partial Y}{\partial x_{2,3}} \end{bmatrix}

각 원소들은 scalar 값인데 그 중 첫번째 원소인 \frac{\partial Y}{\partial x_{1,1}}를 구하기 위해 Y의 원소들을 x_{1,1}로 편미분 하면 다음과 같이 된다.

\frac{\partial Y}{\partial x_{1,1}}=\begin{bmatrix}w_{1,1} & w_{1,2} & w_{1,3} \\\\ 0 & 0 & 0 \end{bmatrix}

응? 갑자기 이건 뭐냐!

예를 들어 y_{1,1}에 있는 x_{1,1}w_{1,1} + x_{1,2}w_{2,1}x_{1,1}로 편미분하면 w_{1,1}가 되고, x_{1,1}w_{1,2} + x_{1,2}w_{2,2}에 대해서도 같은 방식으로 하면 w_{1,2}가 되는 식으로 Y의 모든 6개의 원소에 적용한 것이다. 이런 짓을 matrix X의 모든 원소인 x_{1, 2}, x_{2,1}, x_{2,2}에 대해서도 모두 구하면 다음과 같이 된다.

\frac{\partial Y}{\partial x_{1,1}}=\begin{bmatrix}w_{1,1} & w_{1,2} & w_{1,3} \\\\ 0 & 0 & 0\end{bmatrix} \\\\ \frac{\partial Y}{\partial x_{1,2}}=\begin{bmatrix}w_{2,1} & w_{2,2} & w_{2,3} \\\\ 0 & 0 & 0 \end{bmatrix} \\\\ \frac{\partial Y}{\partial x_{2,1}}=\begin{bmatrix} 0 & 0 & 0 \\\\ w_{1,1} & w_{1,2} & w_{1,3} \end{bmatrix} \\\\ \frac{\partial Y}{\partial x_{2,2}}=\begin{bmatrix} 0 & 0 & 0 \\\\ w_{2,1} & w_{2,2} & w_{2,3} \end{bmatrix}

행렬 X에 대한 scalar L의 편미분 \frac {\partial L}{\partial X}

\frac{\partial Y}{\partial x_{1,1}}은 matrix X를 구성하는 element인 scalar값이다. 위에서 말한것 처럼 연쇄법칙(Chain rule)에 의해 Y의 모든 원소들에 대하여 다음과 같이 나타낼 수 있다.

\frac{\partial L}{\partial x_{1,1}} = \sum_{i=1}^{N} \sum_{j=1}^{M} \frac{\partial L}{\partial y_{i,j}} \cdot \frac{\partial y_{i,j}}{\partial x_{1,1}}

Matrix X의 첫번째 원소는 {\partial L}를 matrix Y의 각 원소들로 나눈 값들에 Y의 각원소들을 X의 첫번째 원소로 편미분한 값들을 곱한 것을 모두 더한 것이다. 말이 드럽게 복잡해 보이지만, 예를들어, 첫번째 원소인 \frac{\partial L}{\partial x_{1,1}}의 값이 다음과 같이 계산된다는 뜻이다.

\frac{\partial L}{\partial x_{1,1}} = (\frac{\partial L}{\partial y_{1,1}} \times \frac{\partial y_{1,1}}{\partial x_{1,1}}) + (\frac{\partial L}{\partial y_{1,2}} \times \frac{\partial y_{1,2}}{\partial x_{1,1}}) + (\frac{\partial L}{\partial y_{1,3}} \times \frac{\partial y_{1,3}}{\partial x_{1,1}}) + (\frac{\partial L}{\partial y_{2,1}} \times \frac{\partial y_{2,1}}{\partial x_{1,1}}) + (\frac{\partial L}{\partial y_{2,2}} \times \frac{\partial y_{2,2}}{\partial x_{1,1}}) + (\frac{\partial L}{\partial y_{2,3}} \times \frac{\partial y_{2,3}}{\partial x_{1,1}})

이것도 뭐 딱히 깨끗해 보이진 않지만… 여튼, matrix Y의 각원소들에 대해 x_{1,1}로 편미분한 결과를 위 식에 적용해 보면 다음과 같이 된다.

\frac{\partial L}{\partial x_{1,1}} = (\frac{\partial L}{\partial y_{1,1}} \times w_{1,1}) + (\frac{\partial L}{\partial y_{1,2}} \times w_{1,2}) + (\frac{\partial L}{\partial y_{1,3}} \times w_{1,3}) + (\frac{\partial L}{\partial y_{2,1}} \times 0) + (\frac{\partial L}{\partial y_{2,2}} \times 0) + (\frac{\partial L}{\partial y_{2,3}} \times 0) \\\\ = (\frac{\partial L}{\partial y_{1,1}} \times w_{1,1}) + (\frac{\partial L}{\partial y_{1,2}} \times w_{1,2}) + (\frac{\partial L}{\partial y_{1,3}} \times w_{1,3})

같은 방법을 \frac{\partial L}{\partial X}의 모든 원소들에 적용하면

\frac{\partial L}{\partial x_{1,1}} = (\frac{\partial L}{\partial y_{1,1}} \times w_{1,1}) + (\frac{\partial L}{\partial y_{1,2}} \times w_{1,2}) + (\frac{\partial L}{\partial y_{1,3}} \times w_{1,3}) \\\\ \frac{\partial L}{\partial x_{1,2}} = (\frac{\partial L}{\partial y_{1,1}} \times w_{2,1}) + (\frac{\partial L}{\partial y_{1,2}} \times w_{2,2}) + (\frac{\partial L}{\partial y_{1,3}} \times w_{2,3}) \\\\ \frac{\partial L}{\partial x_{2,1}} = (\frac{\partial L}{\partial y_{2,1}} \times w_{1,1}) + (\frac{\partial L}{\partial y_{3,2}} \times w_{1,2}) + (\frac{\partial L}{\partial y_{3,3}} \times w_{1,3}) \\\\ \frac{\partial L}{\partial x_{2,2}} = (\frac{\partial L}{\partial y_{2,1}} \times w_{2,1}) + (\frac{\partial L}{\partial y_{3,2}} \times w_{2,2}) + (\frac{\partial L}{\partial y_{3,3}} \times w_{2,3})

이것을 matrix의 형태로 나타내면

\frac{\partial L}{\partial X}=\begin{bmatrix}(\frac{\partial L}{\partial y_{1,1}} \times w_{1,1}) + (\frac{\partial L}{\partial y_{1,2}} \times w_{1,2}) + (\frac{\partial L}{\partial y_{1,3}} \times w_{1,3}) & (\frac{\partial L}{\partial y_{1,1}} \times w_{2,1}) + (\frac{\partial L}{\partial y_{1,2}} \times w_{2,2}) + (\frac{\partial L}{\partial y_{1,3}} \times w_{2,3}) \\\\ (\frac{\partial L}{\partial y_{2,1}} \times w_{1,1}) + (\frac{\partial L}{\partial y_{3,2}} \times w_{1,2}) + (\frac{\partial L}{\partial y_{3,3}} \times w_{1,3}) & (\frac{\partial L}{\partial y_{2,1}} \times w_{2,1}) + (\frac{\partial L}{\partial y_{3,2}} \times w_{2,2}) + (\frac{\partial L}{\partial y_{3,3}} \times w_{2,3})\end{bmatrix}

Matrix Y와 W를 구분해 보면

\frac{\partial L}{\partial X}=\begin{bmatrix}\frac{\partial L}{\partial y_{1,1}} & \frac{\partial L}{\partial y_{1,2}} & \frac{\partial L}{\partial y_{1,3}} \\\\ \frac{\partial L}{\partial y_{2,1}} & \frac{\partial L}{\partial y_{2,2}} & \frac{\partial L}{\partial y_{2,3}}\end{bmatrix} \cdot \begin{bmatrix} w_{1,1} & w_{2,1} \\\\ w_{1,2} & w_{2,2} \\\\ w_{1,3} & w_{2,3}\end{bmatrix}

Weight matrix W의 전치행렬(transpose matrix)를 곱하는 것이 되므로,

\frac{\partial Y}{\partial X}=\frac{\partial L}{\partial Y} \cdot W^T이 성립한다.

Weight에 대한 loss function의 변화인 \frac{\partial L}{\partial W} = X^T \cdot \frac{\partial L}{\partial Y}도 기본적으로 같은 방법으로 유도 되는데 너무 길어져서 2편에서 간단히 다루도록 한다.