본문 바로가기

자연어 처리 과정

Training with SAM optimizer

개요

Accumulation을 사용하는 상황에서 sam optimizer를 사용할 때 어떻게 학습 코드를 작성하면 좋을지 알아보자!

Task는 QA(mrc) task로 설정한다.

 

1. Sam optimizer 기본 사용법

2. Accumulation을 적용한다면?

3. 마무리

 

 

1. Sam optimizer 기본 사용법

Sam optimizer를 사용하는 데에 있어서 다른 optimizer와의 차이점은 바로 두 가지가 있다.

 

- loss backward를 2번 밟는다.

- optimizer의 step을 2번 밟는다.

 

그렇다면 이에 대한 코드를 어떻게 작성해야 하는 걸까?

 

정말 다행스럽게도, 이곳에 sam optimizer를 사용하는 기본적인 방법에 대해 설명이 되어 있다.

 

출처: https://github.com/davda54/sam

위에서 설명했던 차이점 그대로 model의 inference에 대한 loss를 총 2번 구하고, 

parameter updating을 위한 optimizer step을 2번 밟아야 한다.

 

2. Accumulation을 적용한다면?

우리는 학습을 할 때, 제한된 gpu 자원 속에서도 큰 batch의 gradient를 얻을 수 있도록 하기 위해

accumulation을 사용한다.

 

그렇다면, 이때 sam optimizer를 사용하려면 어떻게 코드를 작성해야 할까?

 

(1) 전체 코드

아래에 있는 코드는 모든 학습 과정을 담은 코드이다.

accumulation을 적용하는 데에 있어서 에러가 포함된 코드이다.

어떤 에러가 포함되어 있을까?

train_step_losses = []
dev_step_losses = []
train_losses = []
dev_losses = []
lowest_dev_loss = 9999

epochs = 3
first_step = 0
second_step = 0

for epoch in range(epochs):
    print("Epoch", epoch)
    # Training
    running_loss = 0.
    losses = []
    progress_bar = tqdm(train_dataloader, desc='Train')
    for input, start, end, answer_text, positions, context in progress_bar:
        model.train()
        del answer_text, positions, context
        input.to(device)
        start = start.cuda()
        end = end.cuda()
        
        outputs = model(**input, start_positions=start, end_positions=end)
        loss = outputs.loss
        (loss / accumulation).backward()
        running_loss += loss.item()
        
        first_step += 1
        if first_step % accumulation: # step % acc == 0이 아니면 다시 backward하러 돌아가게끔
            continue
        
        clip_grad_norm_(model.parameters(), max_norm=1.)
        optimizer.first_step(zero_grad=True)

        for _ in range(accumulation):
            second_step += 1
            (model(**input, start_positions=start, end_positions=end).loss / accumulation).backward()

            if second_step % accumulation == 0:
                optimizer.second_step(zero_grad=True)
        scheduler.step()
        # optimizer.zero_grad(set_to_none=True)

        del input, start, end, loss

        train_step_losses.append(running_loss / accumulation)
        losses.append(running_loss / accumulation)
        running_loss = 0.
        progress_bar.set_description(f"Train - Loss: {losses[-1]:.3f}")
    train_losses.append(mean(losses))
    print(f"train score: {train_losses[-1]:.3f}")

    val_losses = []
    for input, start, end, answer_text, positions, context in tqdm(valid_dataloader, desc="Evaluation"):
        model.eval()
        del answer_text, positions, context
        input.to(device)
        start = start.cuda()
        end = end.cuda()
        
        with torch.no_grad():
            outputs = model(**input, start_positions=start, end_positions=end)
        loss = outputs.loss

        dev_step_losses.append(loss.item())
        val_losses.append(loss.item())
        
        del input, start, end, loss
    dev_losses.append(mean(dev_step_losses))
    print(f"Evaluation score: {dev_losses[-1]:.3f}")

    if lowest_dev_loss > dev_losses[-1]:                    
        lowest_dev_loss = dev_losses[-1]
        # torch.save(model.state_dict(), "./model_name.bin")
        model.save_pretrained(f'./model_name'

 

(2) 에러 부분

Sam optimizer의 특징인 first_step, second_step을 밟는 부분들을 살펴보자.

첫 번째 step을 밟는 과정은 정상적으로 보인다.

자세히 살펴볼 부분은 second_step을 밟는 부분이다.

지금 코드에서는 for문을 통해 한 batch에 담긴 input이 모델에게 제공되고 있다.

batch size가 8이라고 가정해보자.

 

그럼 첫 번째 step에서의 output에는 당연하게도 8개의 데이터에 대한 모델의 예측값이 나올 것이다.

그리고 8개 데이터를 받고 예측을 한 것에 대한 loss 값이 backward 되며 그에 대한 gradient가 구해져

sam optimizer의 first step을 밟게 된다.

즉, 정리를 하자면,

 

1. 8개 데이터에 대한 모델의 output

2. output의 loss에 대한 gradient를 backward통해 구한다.

3. parameter를 updating through first step

 

그런데 두 번째 step을 밟는 과정을 자세히 살펴보자.

두 번째 loss를 구할 때 모델에 들어가는 input은 어떤 데이터들일까?

지금 이 방식이 정상적으로 batch 내의 모든 input을 넣어주고 있는 걸까?

 

(3) 에러 이유

잘 생각해보면, 현재 second_loss를 구하기 위해서 들어가고 step의 횟수만큼 모델에 들어가고 있는 input은

마지막 배치에 들어있는 데이터들이다.

Accumulation번째의 즉, 마지막 배치의 데이터 8개만이 second step을 밟는 데에 이용되고 있다.

 

정상적인 gradient 정보를 받아서 parameter를 업데이트 할 가능성이 매우 낮다.

 

(4) 에러 해결 코드

따라서 우리는 아래와 같이 코드를 수정해야 한다.

즉, accumulation step을 밟기 전까지의 input을 모두 리스트에 저장해둔 다음

second_step을 밟을 때, 리스트에 담긴 input 정보를 이용해서 gradient를 구해 parameter를 업데이트 해야 한다.

 

 

3. 마무리

오늘은 이렇게 sam optimizer의 사용법에 대해 알아보았다.

다음에는 sam optimizer가 어떻게 sharp minimizer를 flat minimizer로 바꾸어 일반화 능력을 높이는 건지 알아보도록 하자.