Stanford CS236 Deep Generative Models 수업의 자료를 기반으로 생성모델의 기본 개념들을 정리해보고자 한다. (참고 https://deepgenerativemodels.github.io/syllabus.html)
먼저 Autoregressive model 의 정의는 다음과 같다. (by charGPT)
“An autoregressive model is a type of statistical model used to describe certain time-dependent processes. It predicts the next value in a sequence based on the previous values, assuming that each value depends linearly on its predecessors and a stochastic term (random error). This approach is commonly used in time series analysis and language modeling, where each output is generated based on previously generated outputs.
자기회귀 모델(Autoregressive Model)은 시간에 따라 변화하는 데이터를 설명하는 데 사용되는 통계 모델 중 하나입니다. 이 모델은 이전 값들에 기초하여 다음 값을 예측하며, 각 값이 선형적으로 이전 값들과 랜덤 오차에 의존한다고 가정합니다. 주로 시계열 분석과 언어 모델링에서 사용되며, 이전에 생성된 출력에 기반하여 다음 출력을 생성합니다.”
즉, 쉽게 말해 이전 값들을 가지고 다음 값을 예측하는 모델을 autoregressive model 이라고 한다.
하나의 간단한 예시를 들어보자. MNIST 데이터셋에서 이미지의 확률분포를 얻는 문제이다.
확률분포 p(x)의 출력값은 x라는 이미지가 MNIST 데이터셋의 이미지와 비슷하게 생겼으면 높은 값을 가지도록 만들어야한다. x라는 이미지는 28×28로 이루어져있다고 하면 784개의 픽셀값(0 or 1)을 가지니까 p(x1, … x784)로 나타낸다. 기본적으로 chain rule을 이용하면 p(x)는 아래와 같이 나타내질 수 있다.
위 식을 보면 결국 p(x)는 각 픽셀들의 확률의 곱으로 구해지는데 각 픽셀들의 확률은 이전 픽셀들의 값이 뭐냐에 따라 달라질테니까 이전 픽셀들 값을 condition으로 받게 된다. 위에서 얘기한 autoregressive 모델이 나오게 되는것이다. 각 픽셀에 대해 모두 autoregressive model로 확률을 계산한 값을 모두 곱하면 최종적으로 이미지가 MNIST와 유사한지 안한지에 대한 확률을 계산할 수 있게 되는것이다. 이러한 모델을 직관적으로 디자인 하면 아래와 같은 형태로 할 수 있고 이것이 FVSBN(Fully Visible Sigmoid Belief Network) 이다.
이 모델의 파라미터인 \( \alpha \) 값들을 MNIST 이미지들 가지고 학습을 시키면 이 모델은 픽셀값들이 주어졌을 때 다음 픽셀값이 1일 확률을 예측하게 되는 모델이 된다.(학습 방법은 일단 생략, 다음 포스팅에서 다룸) 즉, 픽셀값들이 주어졌을 때 다음 픽셀값이 MNIST스러우면 높은 값이 나올것이고 아니면 낮은 값이 나올 것이다. 이렇게 잘 학습시킨 모델이 있다고 하면 이 모델을 가지고 새로운 sample을 generate할 수 있게 되는데 sampling 과정은 아래와 같다.
일단 x1을 p(x1) 확률을 가지고 sampling한다. (p(x1)은 단순히 \( \alpha^1 \) 이고 \( \alpha^1 \) 값이 0.7이면 10번중에 7번은 1, 3번은 0으로 sampling 될것이다. )
이후 x1 값을 가지고 x2 픽셀의 확률을 예측하고 이 확률대로 sampling한다.
또 이렇게 sampling된 x1,x2 값을 가지고 x3 값의 확률을 예측, 확률대로 sampling한다. 이를 반복하면 최종 이미지를 얻을 수 있다.
사실상 autoregressive 모델을 모두 이 원리를 따라 학습을 하고 sample을 생성을 하게 된다. 여기서 구조만 조금씩 바뀌는데 위의 FVSBN은 이전 픽셀값에 대한 wighted linear sum을 하고 sigmoid계산을 한게 끝이고 이 대신 안에 neural network를 넣어 Wx+b 형태로 계산되게 하고 같은 위치의 픽셀에 대해서는 계속 같은 weight가 적용되어 계산하게 하면(x1에는 항상 w1이 적용, FVSBN은 x1에 대한 weight인 \( \alpha \) 값이 x2, x3를 계산할때 매번 달라져서 파라미터가 계속 증가했었음) 아래와 같이 NADE (Neural Autoregressive Density Estimation) 모델이 된다.
위에서는 픽셀값이 0 or 1인 경우만 다루고 있어서 output이 확률이면 그냥 그 확률대로 0 or 1을 sampling하면 끝인데 이 픽셀값이 0~255이거나 continuous일때는 어떻게 구조를 바꿔야할까?
먼저 0~255개의 discrete한 값을 sampling하고 싶으면 output을 그냥 1개의 확률로 뽑는게 아니라 256개의 output으로 뽑고 softmax를 취하면 각 0~255 값에 대한 확률이 나오게 되니까 그 확률대로 sampling을 해주면 된다.
Continuous한 경우에는 output을 여러개의 가우시안 확률분포의 mean, std로 만들어서 뽑은 다음에 이 가우시안들을 합해서 mixture of Gaussian 분포에서 sampling을 해주면 된다.
이 autoregressive 모델들을 보다보면 autoencoder와 형태가 비슷하다. 따라서 autoencoder를 autoregressive 모델처럼 생성모델로 쓰려고 한것이 아래의 MADE (Maskted Autoencoder for Distribution Estimation) 이다.
방법은 위와 같은데 요약하면
- 먼저 모든 unit에 숫자를 매기는데 x1~xn 까지는(output, input 동일하게) 1~n 까지 숫자를 각각 배정하고 중간 unit은 그냥 1~n중에 random으로 숫자를 매긴다.
- 최종 output layer에서는 각각 자신보다 낮은 숫자를 가진 unit과만 연결한다.
- 그 아래 layer에서는 자신과 같거나 낮은 숫자를 가진 unit과 연결한다.
이렇게 하면 최종 ouptut 에서는 input에서 본인보다 작은 숫자를 가진 값들과만 연결된다. 즉 autoregressive 형식으로 output을 얻게 되는것이다. 그러면 이 autoencoder 모델로 x1부터 하나하나 sampling해나가면 새로운 sample을 얻을 수 있다!
여튼 이런 autoregressive 모델중에 대표적인 모델 중에 하나로 RNN이 있다.
장단점으로는 일단 입력 sequence가 길이가 얼마나 되든 잘 학습된 충분한 길이의 RNN이 있으면 그 안에서 처리가 되기 때문에 입력 길이가 자유롭다는 점이 있다.
단점으로는 autoregressive의 어쩔수없는 단점으로 sequential하게 값을 처리해야하기 때문에 병렬화가 어렵다. (xn을 얻기 위해서는 x1, x2 … xn-1 까지 얻어야 계산이 가능하므로 이를 기다려야하기 때문)
그리고 이러한 autoregressive 모델을 이용해서 RGB 이미지를 생성하는것이 아래의 Pixel RNN이 된다.
전체 확률을 r값의 확률 ,g값의 확률, b값의 확률로 각각 곱해서 얻으면서 각각을 또 autoregressive하게 얻어내는 구조이다. 이러한 autoregressive구조를 CNN으로도 구현할 수 있는데 아래가 이에 대한 PixelCNN이다. convolution filter를 아래와 같이 디자인해서 자기 픽셀값 이전 픽셀값들만 입력으로 볼 수 있게 한 것인데 CNN특성상 blind spot이 생기긴 한다.
이런 생성모델을 가지고 p(x)를 잘 학습시켜놓으면 adversarial example이 들어오면 p(x)값이 확 낮아져서 attack을 감지할수 있다. 또한 attack이 들어간 이미지를 pixelCNN이나 이런 생성모델로 새로 생성하면 이미지의 attack을 제거할수도 있다.
마지막으로 speech에서도 WaveNet같은 모델이 autoregressive를 이용한다.
답글 남기기
댓글을 달기 위해서는 로그인해야합니다.