torchtext

Reference : Torchtext Tutorial

처음에 왜 이렇게 어려워 … 하 파이토치 언제 적응하냐… 했었는데 너무 감사하게도 잘 정리해준 어느 분의 블로그를 발견하였다. 🙇‍♀️ 2018년 글이어서 이후 버전에서 조금씩 변경된 사항이 있을 수는 있을 것 같다. 차후에 더 알게 된 내용들을 정리해보아야겠다

STEP 1. Field ⛳️

Field 는 추후 텐서로 바뀔 텍스트 데이터를 처리 하는 기능을 한다. 여기서 처리에는 토큰화 등의 기능이 있다. 공식 도큐멘트를 살펴보면 이런 이런 설정들을 할 수 있다

Field 공식 도큐멘트

  • use_vocab – Whether to use a Vocab object. If False, the data in this field should already be numerical. Default: True.
  • tokenize – The function used to tokenize strings using this field into sequential examples. If “spacy”, the SpaCy tokenizer is used. If a non-serializable function is passed as an argument, the field will not be able to be serialized. Default: string.split.
  • batch_first – Whether to produce tensors with the batch dimension first. Default: False.
  • stop_words – Tokens to discard during the preprocessing step. Default: None
  • sequential – Whether the datatype represents sequential data. If False, no tokenization is applied. Default: True.

그러니까 텍스트 데이터들 중 같이 한 번에 처리할 애들끼리 각 각 Field 를 만들어주면 된다. @simonjisu 님 이 해당 글에서 작성해주신 코드를 가져왔다.

from torchtext.data import Field

TEXT = Field(sequential=True,
             use_vocab=True,
             tokenize=str.split,
             lower=True, 
             batch_first=True)  
LABEL = Field(sequential=False,  
              use_vocab=False,   
              preprocessing = lambda x: int(x),  
              batch_first=True)

STEP 2. 데이터 셋 만들기 (feat. TabularDataset)

TabularDataset 은 데이터를 지정해준 경로에서 불러와 위에서 만들어준 Field에 처리해준다. [('필드이름(임의지정)', 필드객체), ('필드이름(임의지정)', 필드객체)] 로 넣어준다.

from torchtext.data import TabularDataset

train_data = TabularDataset.splits(path='./data/',
					train='train_path',
					valid='valid_path',
					test='test_path',
					format='tsv', 
					fields=[('text', TEXT), ('label', LABEL)])
train_data = TabularDataset(path='./data/examples.tsv', 
				format='tsv', 
				fields=[('text', TEXT), ('label', LABEL)])

STEP 3. 단어장 생성

<unk> : 0 , <pad> : 1

토큰과 integer index 를 매칭시켜준다.

TEXT.build_vocab(train_data)

STEP 4. 데이터 로더 만들기

  1. 일반적 Iterator : 여기서 return 되는 train_loader, valid_loader 부터는 이제 익숙한 파이토치 문법이다. 차례로 batch 돌리면서 학습시키면 된다.

    from torchtext.data import Iterator
       
    train_loader, valid_loader, test_loader = \
    	TabularDataset.splits((train_data, valid_data, test_data), 
    				batch_size=3, 
    				device=None,  # gpu 사용시 "cuda" 입력
    				repeat=False)
    
    train_loader = Iterator(train_data, 
    			batch_size=3, 
    			device=None,  # gpu 사용시 "cuda" 입력
    			repeat=False)
    
  2. BucketIterator : 해당 batch 내에서 비슷한 길이의 문장끼리 batch 를 구성할 수 있도록 sort_key 를 제공한다. 🍇 Better Batches with PyTorchText BucketIterator

​ 좀 삐까뻔쩍한 기능인 것 같다 !

​ 이런 식으로 사용할 수 있다.

torchtext_train_dataloader, torchtext_valid_dataloader = torchtext.data.BucketIterator.splits(
    
                              # Datasets for iterator to draw data from
                              (train_dataset, valid_dataset),

                              # Tuple of train and validation batch sizes.
                              batch_sizes=(train_batch_size, valid_batch_size),

                              # Device to load batches on.
                              device=device, 

                              # Function to use for sorting examples.
                              sort_key=lambda x: len(x['text']), #필드 이름


                              # Repeat the iterator for multiple epochs.
                              repeat=True, 

                              # Sort all examples in data using `sort_key`.
                              sort=False, 

                              # Shuffle data on each epoch run.
                              shuffle=True,

                              # Use `sort_key` to sort examples in each batch.
                              sort_within_batch=True,
                              )

# Print number of batches in each split.
print('Created `torchtext_train_dataloader` with %d batches!'%len(torchtext_train_dataloader))
print('Created `torchtext_valid_dataloader` with %d batches!'%len(torchtext_valid_dataloader))

데이터 로더까지 만들어줬으면 이후 사용은 내가 익숙한 평소 파이토치 학습 순서대로 간다.

for batch in train_loader:
    break
print(batch.text)
print(batch.label)

해당 코드를 제공해주신 블로거 분께 다시 한 번 감사드린다. 🙇‍♀️