개요
RNN에서 나타나는 데이터의 차원을 이해하고
LSTM cell의 output logit을 의미하는 ox가 담긴 ox_batch를 이용해서
loss 값을 구하려면 어떤 shape을 맞춰주어야 하는지 알아보자.
코드
ox_batch = torch.cat(ox_batch).reshape(max_seq_len, batch_size, -1) # (50, 64, d)
ox_batch = ox_batch.permute(1,0,2).reshape(batch_size*max_seq_len, -1) # (64, 50, d).reshape(3200, d)
y_batch = y_batch.reshape(-1)
ox_batch: LSTM cell의 logit인 ox가 담겨있다.
ox_batch는 원래 list형이었다.
과정
1. list였던 ox_batch를 cat method를 이용해서 쌓아 torch형으로 만든다.
이때, shape은 max_seq_len x batch_size x voca_dim으로 한다.
2. 하지만 이를 text를 담아둘 때 기본 shape인 batch_size x max_seq_len x voca_dim으로 한다.
permute를 이용하면 각 차원이 index가 되어 순서를 바꿔줄 수 있다.
-> 만약 shape이 (10, 5, 1)이라고 하면 permute(2, 0, 1)을 해주면 shape이 (1, 10, 5)로 변한다.
3. 그리고 reshape을 통해 Large batch x voca_dim으로 shape을 변경해준다.
이렇게 해주어야 loss function의 input에 맞는 shape를 갖출 수 있다.
BxSxD shape으로는 loss function에 input을 해줄 수가 없다.
4. y_batch도 마찬가지로 loss function으로의 입력을 위해 한 줄로 펴준다.
-> shape이 Large batch x 1로 변한다.
정리
ox_batch의 shape 변화
S x B x D
⬇️
B x S x D(RNN 기본형)
⬇️
Large B x D(for objective function)
y_batch의 shape 변화
B x S
⬇️
Large B x 1(for objective function)
'자연어 처리 과정' 카테고리의 다른 글
Why RNN share the same weights? (0) | 2022.12.28 |
---|---|
Word2vec vs GloVe (0) | 2022.12.27 |
LSTM (0) | 2022.12.18 |
RNN (0) | 2022.12.18 |
VGG (0) | 2022.12.18 |