아인슈타인 표기법과 einsum 함수
아인슈타인 표기법(Einstein notation)은 첨자가 있는 값들의 합을 편리하게 나타내는 표기 방법이다.
이 표기법에 따라 계산하는 함수인 einsum
은 NumPy, PyTorch 등의 파이썬 라이브러리에서 지원한다. 이 함수를 사용하여 복잡한 연산을 간결하게 나타낼 수 있으며 상황에 따라 성능도 향상시킬 수 있다.
아인슈타인 표기법
아인슈타인 표기법을 도입한 것은 그 유명한 알베르트 아인슈타인(Albert Einstein)이다. 아인슈타인은 일반 상대성 이론에서 나오는 수많은 변수의 합을 간단히 표현하기 위해 이 표기법을 도입하였다. 이 표기법은 주로 물리학에서 사용되다가, 최근에는 머신러닝 등에서 복잡한 텐서 연산을 표현하기 위해 활용되기도 한다.
정확하게 말하면, 여기서 설명할 아인슈타인 표기법은 물리학에서 사용되는 것과 약간 차이가 있다. 여기서는 NumPy 등 고차원 텐서를 사용하는 라이브러리의 einsum
함수에서 사용되는 표기법을 다룰 것이다.
아인슈타인 표기법의 핵심은 합 기호($\sum$)를 제거하고 이것 없이도 원소들이 어떻게 더해지는지 추론할 수 있도록 하는 것이다. 다음 예를 살펴보자. $A$는 $l \times m$ 크기의 행렬, $B$는 $m \times l$ 크기의 행렬이라고 하자. 이때 $C = AB$는
\[C_{ik} = \sum_{j=1}^m A_{ij} B_{jk}\]로 표현된다. ($A_{ij}$는 $A$의 $(i, j)$번째 성분을 의미한다.)
이 식의 우변을 살펴보자. $j$는 합 기호에서 사용되는 첨자이다. 즉 합 기호 밖에서는 존재하지 않는, 합 기호에 묶여있는 첨자이다. 따라서 좌변의 첨자로는 존재하지 않는다. 반면 $i$와 $k$는 성분 $C_{ik}$의 위치에 따라 결정되는 첨자로, 우변에서 특정 합 기호에 묶여있지 않은 자유 첨자라고 볼 수 있다.
식이 잘 정의되기 위해서는 모든 첨자가 항상 좌변의 첨자로 들어가 성분의 위치를 표현하는 자유 첨자이거나, 또는 우변에서 합 기호에 의해 묶인 첨자여야 함을 알 수 있다. 또한 $j$가 순회해야 하는 범위는 $1$부터 $m$까지임을 첨자의 위치로부터 알 수 있다. $j$는 $A$의 열과 $B$의 행에 들어가고, $A$의 크기는 $l \times m$, $B$의 크기는 $m \times l$이므로, $j$는 $1$부터 $m$까지 움직여야 한다.
따라서 합 기호를 제거하여
\[C_{ik} = A_{ij} B_{jk}\]와 같이 쓰더라도, 자유 첨자가 아닌(좌변에 존재하지 않는) $j$에 대해 $1$부터 $m$까지 더하는 합 기호가 생략되어 있음을 추론할 수 있다. 이것이 아인슈타인 표기법이다.
입력 텐서가 3개 이상이거나 동일한 첨자가 3번 이상 포함되는 것도 가능하다. 예를 들어,
\[d_{i} = a_{i} b_{i} c_{i}\]는 세 벡터 $a$, $b$, $c$의 성분별 곱이다.
입력 텐서에 동일한 첨자가 여러 번 들어가는 것도 가능하다. 예를 들어,
\[b = A_{ii}\]는 $i$에 대한 합 기호가 생략된 것으로, 행렬 $A$의 대각합(trace)이다 (단 $A$는 정방행렬이어야 한다). 이 경우 $A$의 대각선 성분만이 계산에 사용된다.
단 $B_{ii} = a_i$와 같이 결과 텐서에는 첨자가 중복될 수 없다. 첨자가 중복으로 들어갈 경우 무시되는 성분이 생기는데, 결과 텐서가 잘 정의되기 위해서는 모든 성분이 명시적으로 설정되어야 하기 때문이다. 이 표현식의 경우 대각선이 아닌 성분, 즉 $i \neq j$일 때 $B_{ij}$의 값이 정의되지 않으므로 잘못되었다.
합을 제거하는 과정에서 의문이 들 수도 있다. 합이 여러 개인 경우 표기법으로부터 원래의 합 순서를 복원할 수 없기 때문이다. 그러나 합의 순서는 결과에 영향을 주지 않으므로 문제가 되지 않는다. 첨자에 대한 합은 순서를 바꾸어도 결과가 동일하기 때문이다. 즉
\[\sum_{i=1}^m \sum_{j=1}^n f(i, j) = \sum_{j=1}^n \sum_{i=1}^m f(i, j)\]가 성립하기 때문이다.
표기법으로부터 원본 식을 복원하는 방법 요약
여기서는 아인슈타인 표기법으로부터 원본 식으로 복원하는 방법을 간단히 정리한다.
- 첨자들 중 연산 결과 성분의 첨자로 있는 것을 자유 첨자로 분류하고, 나머지를 묶인 첨자로 분류한다.
- 묶인 첨자가 표현하는 텐서 성분의 길이를 구한다. (예를 들어, 텐서가 $m \times n$ 크기의 행렬이고 첨자가 이 행렬의 열 위치를 나타낸다면 길이는 $n$이다.)
- 묶인 첨자에 해당하는 합 기호를 추가한다. 첨자의 범위는 $1$부터 위 과정에서 구한 길이까지이다.
einsum
함수
einsum
은 여러 라이브러리에서 아인슈타인 표기법에 따라 계산하는 함수의 이름으로 사용된다. 여기서는 NumPy를 기준으로 할 것이다. 다른 라이브러리에서도 기본적인 사용법은 같을 것이나 세부 사항이 다를 수 있음에 주의하자.
NumPy의 einsum
의 사용법은 다음과 같다.
입력 텐서의 개수는 총 $n$개이고, $a_k$는 $k$번째 입력 텐서를 뜻한다. $[i_k]$는 $k$번째 입력 $a_k$의 첨자 목록을 의미한다.
예를 들어, 행렬곱을 나타내는 $C_{ik} = A_{ij} B_{jk}$는 einsum
으로
np.einsum("ij,jk->ik", A, B)
로 표현된다.
NumPy의 einsum
함수는 타입 변환 방식, 최적화 방법 등 추가 설정을 할 수 있다. 자세한 내용은 매뉴얼에서 확인할 수 있다.
예시
Tim Rocktäschel의 글에서 다양한 예시를 볼 수 있다.
항등 변환 (Identity)
입력과 출력의 첨자를 동일하게 설정한다. 아래는 하나의 차원을 가졌을 때의 예이다.
\[b_i = a_i\]np.einsum("i->i", a)
성분의 합 (Sum of Entries)
출력의 첨자가 없도록 한다. 아래는 하나의 차원을 가졌을 때의 예이다.
\[b = a_{i}\]np.einsum("i->", a)
성분별 곱 (Element-wise Product)
두 입력과 출력의 첨자를 동일하게 설정한다. 아래는 하나의 차원을 가졌을 때의 예이다.
\[c_{i} = a_{i} b_{i}\]np.einsum("i,i->i", a, b)
전치 (Transpose)
\[B_{ij} = A_{ji}\]np.einsum("ji->ij", A)
대각합 (Trace)
\[c = A_{ii}\]np.einsum("ii->", A)
내적 (Dot Product)
\[c = a_{i} b_{i}\]np.einsum("i,i->", a, b)
외적 (Outer Product)
\[C_{ij} = a_{i} b_{j}\]np.einsum("i,j->ij", a, b)
선형 변환 (Linear Transformation)
\[c_{i} = A_{ij} b_{j}\]np.einsum("ij,j->i", A, b)
행렬곱 (Matrix Multiplication)
\[C_{ik} = A_{ij} B_{jk}\]np.einsum("ij,jk->ik", A, B)
배치 행렬곱 (Batch Matrix Multiplication)
배치(batch)별로 행렬곱을 수행하는 연산이다. PyTorch의 bmm
함수로 수행할 수 있다.
np.einsum("bij,bjk->bik", A, B)
이차형식 (Quadratic Form)
여기서 $A$는 정방행렬이어야 한다.
\[b = x_{i} A_{ij} x_{j}\]np.einsum("i,ij,j->", x, A, x)
활용
einsum
은 복잡한 연산을 간결하게 표현할 수 있다. einsum
을 쓰지 않는다면 반복문을 사용하거나 내적, 행렬곱과 같은 연산을 조합해야 하는데, 아인슈타인 표기법에 비해 복잡할 수 있다. 그러나 einsum
에 익숙하지 않은 사람의 입장에서는 오히려 가독성이 낮게 느껴질 수 있으므로 일장일단이 있다.
einsum
은 적절히 사용 시 효율성이 좋다. 파이썬 자체의 계산 효율이 낮기 때문에, 파이썬에서 직접 반복문으로 아인슈타인 표기법에 따라 계산하는 것은 매우 효율이 떨어진다. 그러나 einsum
은 NumPy나 PyTorch와 같은 라이브러리 내부의 최적화된 연산을 사용하므로 훨씬 효율적이다.
반면 라이브러리에서 지원하는 연산의 경우 einsum
으로 구현하는 것보다 더 효율적일 가능성이 높다. 예를 들어, NumPy의 경우 내적, 행렬곱, 대각합(trace) 등의 연산을 지원하는데, 이들은 einsum
으로도 구현할 수 있지만 제공되는 연산을 활용하는 것이 일반적으로 더 낫다. einsum
은 상당히 보편적인 연산을 수행할 수 있으므로, 이에 따라 자연스럽게 오버헤드가 발생할 수밖에 없기 때문이다.
그러나 수행하고자 하는 연산이 이러한 기본 연산(내적, 행렬곱 등) 한 번이 아니라 여러 개의 조합으로 표현되는 경우 einsum
이 더 효율적일 수 있다. 이러한 연산을 여러 번 조합할 경우, 계산 과정에 중간 결과물로써 생기는 결과값의 인스턴스에 의해 오버헤드가 발생할 수 있다. einsum
은 모든 연산 과정을 라이브러리 내부에서 하므로 이런 문제가 적을 것이다.
결론적으로, einsum
으로 구현하기 좋은 연산은 내적, 행렬곱과 같은 기본 텐서 연산으로는 표현하기 어렵거나 복잡한 것들이다. 이 경우 einsum
을 활용하면 코드를 간결하게 나타내고 효율적으로 계산할 수 있다. 이러한 연산은 특히 딥 러닝에서 높은 차원을 가진 텐서를 조작할 때 자주 사용된다.
Leave a comment