Karpathy-GPT만들기

Karpathy - Let's build GPT 실습하기 (2)

leejeong6 2025. 4. 25. 15:19

이제 Class로 만들어서 loss까지 구현합니다

import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self,vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size,vocab_size)

    def forward(self,idx,targets=None):
        logits = self.token_embedding_table(idx) # B,T,C = 32,8,65
        if targets is None:
            loss = None 
        else:    
            B,T,C = logits.shape
            
            logits = logits.view(B*T,C)
            
            
            targets = targets.view(B*T)
            
            loss = F.cross_entropy(logits,targets)
            
        return logits,loss

    def generate(self,idx,max_new_tokens):
        for _ in range(max_new_tokens):
            logits,loss = self(idx)
            logits = logits[:,-1,:]
            probs = F.softmax(logits,dim=-1)
            idx_next = torch.multinomial(probs,num_samples=1)
            idx = torch.cat((idx,idx_next),dim=1)
        return idx
    

m = BigramLanguageModel(vocab_size)
logits,loss = m(xb,yb)

print(loss.shape)

print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

nn.Embedding은 (x,y)크기의 테이블을 만들어둬서 input이 들어오면 테이블에 맞는 임베딩값을 뽑아주는 함수라고 생각하면 됩니다

Embedding(65,65)로 만들었다는건 임베딩 65*65행렬로 만들어져있다는거고, idx만큼 뽑아서 logits에 할당하면 logits은 원래 가지고 있던 행렬에 한 차원 더 생겨서 임베딩 값이 생기게 됩니다. 32*8 -> 32*8*65

다음은 cross_Entropy입니다

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss

 

CrossEntropyLoss — PyTorch 2.7 documentation

Shortcuts

pytorch.org

 

읽어보면 C에 대해 Cross Entropy를 아래 식처럼 계산한다고 합니다.

minibatch,C 형태로 입력되어야 한다 하므로 

32*8*65가 아니라 256*65로 만들어서 입력합니다

target도 마찬가지로 이와 같은 과정을 하면 loss는 스칼라 값으로 나오게 됩니다

그러면 함수의 return인 logits는 256*65 loss는 스칼라가 되겠네요

 

다음은 generate함수입니다

위 forward식을 통해 나온 logits,loss를 이용하는데 logits를 [:,-1,:]으로 splicing합니다

그 이유는 logits는 현재 B*T,C인데 Text의 마지막 글자만 뽑아서 B,C형태로 만든다는 것입니다

그런 뒤에 Softmax를 취하는데, softmax에 dim이-1인것은 마지막 차원을 따라 softmax하겠다는 것입니다

즉 C를 따라 softmax하니까 B,C가 모두 확률로 나와있겠고 그 중에서 multinomial함수를 통해

가장 확률이 큰 num_samples 수만큼 뽑아줍니다. 그러면 idx_next도 B,1만큼 나올 것이고

torch.cat으로 idx와 합쳐주면 B,T+1차원으로 idx가 업데이트 됩니다. 

이게 max_new_tokens수만큼 반복되므로 결국 최종 idx는 B,T+100이 decode에 들어가게 되고 32개의 문장이 100개의 길이를 가진 채로 나오게 되겠네요

 

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

batch_size = 32
for steps in range(100): # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

이렇게 optimizer까지 정의해서 코드를 얼추 완성할 수 있습니다

여기서 제가 헷갈렸던 부분은 forward와 generate함수에서의 차원 계산인데요

디폴트로 targets가 None이기 때문에 generate함수를 실행했을 때는 if문만 실행돼서 logits는 항상 B,T,C차원이고 loss는 None이라는 점을 잊으면 안됩니다. 

 

출력결과를 보면 

oTo.JUZ!!zqe!
xBP qbs$Gy'AcOmrLwwt
p$x;Seh-onQbfM?OjKbn'NwUAW -Np3fkz$FVwAUEa-wzWC -wQo-R!v -Mj?,SPiTyZ;o-opr$mOiPJEYD-CfigkzD3p3?zvS;ADz;.y?o,ivCuC'zqHxcVT cHA
rT'Fd,SBMZyOslg!NXeF$sBe,juUzLq?w-wzP-h
ERjjxlgJzPbHxf$ q,q,KCDCU fqBOQT
SV&CW:xSVwZv'DG'NSPypDhKStKzC -$hslxIVzoivnp ,ethA:NCCGoi
tN!ljjP3fwJMwNelgUzzPGJlgihJ!d?q.d
pSPYgCuCJrIFtb
jQXg
pA.P LP,SPJi
DBcuBM:CixjJ$Jzkq,OLf3KLQLMGph$O 3DfiPHnXKuHMlyjxEiyZib3FaHV-oJa!zoc'XSP :CKGUhd?lgCOF$;;DTHZMlvvcmZAm;:iv'MMgO&Ywbc;BLCUd&vZINLIzkuTGZa
D.?

이런식으로 개판입니다.

그 이유는 아직 아무것도 학습이 되지 않았기 때문이죠

 

정리해보자면 

앞장에서는 인코딩 된 문장이 다음으로 나올 단어를 예측하는 것을 간단하게 실습해봤습니다

[11] -> 311
[11, 311] -> 10477
[11, 311, 10477] -> 1077
[11, 311, 10477, 1077] -> 3663
[11, 311, 10477, 1077, 3663] -> 26
[11, 311, 10477, 1077, 3663, 26] -> 369
[11, 311, 10477, 1077, 3663, 26, 369] -> 1077
[11, 311, 10477, 1077, 3663, 26, 369, 1077] -> 8571

위와 같이 말이죠.

근데 배치 단위로, 그리고 정답값에 가까워지게, 그리고 임베딩으로 표현하자는 것입니다

먼저 embedding table을 통해서 각 단어가 65크기의 임베딩 값을 가지게 합니다. 그런 뒤에

generate함수를 통해 위와 같이 max_new_tokens수만큼의 문장을 뽑게 됩니다.

idx가 어떻게 변하는지를 보면 아래와 같이 계속 생성되며 max_new_tokens까지 가게됩니다.

[[0, 31]]
[[0, 31, 23]]
[[0, 31, 23, 21]]
[[0, 31, 23, 21, 41]]
[[0, 31, 23, 21, 41, 24]]
[[0, 31, 23, 21, 41, 24, 32]]
[[0, 31, 23, 21, 41, 24, 32, 11]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60, 38]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60, 38, 60]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60, 38, 60, 1]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60, 38, 60, 1, 15]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60, 38, 60, 1, 15, 12]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60, 38, 60, 1, 15, 12, 52]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60, 38, 60, 1, 15, 12, 52, 55]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60, 38, 60, 1, 15, 12, 52, 55, 7]]
[[0, 31, 23, 21, 41, 24, 32, 11, 13, 41, 17, 24, 25, 53, 32, 40, 60, 38, 60, 1, 15, 12, 52, 55, 7, 29]]

 

 

 

'Karpathy-GPT만들기' 카테고리의 다른 글

Karpathy - Let's build GPT 실습하기 (1)  (0) 2025.04.24