본문 바로가기

자연어 처리 과정

Time sequence로 정렬하기

개요

RNN에서 나타나는 데이터의 차원을 이해하고

LSTM cell의 output logit을 의미하는 ox가 담긴 ox_batch를 이용해서

loss 값을 구하려면 어떤 shape을 맞춰주어야 하는지 알아보자.

 

코드

ox_batch = torch.cat(ox_batch).reshape(max_seq_len, batch_size, -1)    # (50, 64, d)
ox_batch = ox_batch.permute(1,0,2).reshape(batch_size*max_seq_len, -1) # (64, 50, d).reshape(3200, d)
y_batch = y_batch.reshape(-1)

ox_batch: LSTM cell의 logit인 ox가 담겨있다.

ox_batch는 원래 list형이었다.

 

과정

1. list였던 ox_batch를 cat method를 이용해서 쌓아 torch형으로 만든다.

이때, shape은 max_seq_len x batch_size x voca_dim으로 한다.

 

2. 하지만 이를 text를 담아둘 때 기본 shape인 batch_size x max_seq_len x voca_dim으로 한다.

permute를 이용하면 각 차원이 index가 되어 순서를 바꿔줄 수 있다.

 

-> 만약 shape이 (10, 5, 1)이라고 하면 permute(2, 0, 1)을 해주면 shape이 (1, 10, 5)로 변한다.

 

3. 그리고 reshape을 통해 Large batch x voca_dim으로 shape을 변경해준다.

이렇게 해주어야 loss function의 input에 맞는 shape를 갖출 수 있다.

BxSxD shape으로는 loss function에 input을 해줄 수가 없다.

 

4. y_batch도 마찬가지로 loss function으로의 입력을 위해 한 줄로 펴준다.

-> shape이 Large batch x 1로 변한다.

 

정리

ox_batch의 shape 변화

S x B x D

 

⬇️

 

B x S x D(RNN 기본형)

 

⬇️

 

Large B x D(for objective function)

 

 

 

 

y_batch의 shape 변화

B x S

 

⬇️

 

Large B x 1(for objective function)

 

 

 

'자연어 처리 과정' 카테고리의 다른 글

Why RNN share the same weights?  (0) 2022.12.28
Word2vec vs GloVe  (0) 2022.12.27
LSTM  (0) 2022.12.18
RNN  (0) 2022.12.18
VGG  (0) 2022.12.18