1. Introduction

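Releases after torchtext 0.9.0 dropped the legacy interfaces such as Field (they were first moved into torchtext.legacy in 0.9.0 and removed outright in 0.12.0). The goal is a library that works like torch's DataLoader, where you control every step yourself. As a result, a lot of older code no longer runs, and there is still no detailed guide to the new API. After reading around, I wrote down my own understanding here. The post that helped me most was "Torchtext 0.12+新版API学习与使用示例(1)" (Torchtext 0.12+ new API: learning and usage examples, part 1), and I build on it below.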

2. The Legacy API

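To help readers follow along, the complete official code for the legacy API is reproduced here. Read it if you are interested; skip ahead if you only want the new API. I have hardly used the legacy API, and I am now on torch 2.0.1, which no longer lets me install an old torchtext, so no output is shown for this section.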

# Using the legacy API
import torchtext
import torch
from torchtext.legacy import data
from torchtext.legacy import datasets

# Step 1 Create a dataset object
TEXT = data.Field() 
LABEL = data.LabelField(dtype = torch.long)
legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL) # use the IMDB dataset from torchtext.legacy.datasets
legacy_examples = legacy_train.examples # list of samples; the first one is printed below
print(legacy_examples[0].text, legacy_examples[0].label)

# Step 2 Build the data processing pipeline
"""
The default tokenizer implemented in the Field class is the built-in python split() function. Users choose the tokenizer by calling data.get_tokenizer(), and add it to the Field constructor. For the sequence model, it's common to append <BOS> (begin-of-sentence) and <EOS> (end-of-sentence) tokens, and the special tokens need to be defined in the Field class.
"""
TEXT = data.Field(tokenize=data.get_tokenizer('basic_english'),
                  init_token='<SOS>', eos_token='<EOS>', lower=True) # tokenize is the tokenization method; you can supply your own function or callable class
LABEL = data.LabelField(dtype = torch.long)
legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL)  # datasets here refers to torchtext.legacy.datasets

"""
Now you can create a vocabulary of the words from the text file stored in the predefined Field object, TEXT. You first have to build a vocabulary in your Field object by passing the dataset to the build_vocab func. The Field object builds the vocabulary (TEXT.vocab) on a specific data split.
"""
TEXT.build_vocab(legacy_train) # build the vocabulary with build_vocab()
LABEL.build_vocab(legacy_train)
"""
Things you can do with a vocabulary object
    Total length of the vocabulary
    String2Index (stoi) and Index2String (itos)
    A purpose-specific vocabulary which contains word appearing more than N times
"""
legacy_vocab = TEXT.vocab
print("The length of the legacy vocab is", len(legacy_vocab)) # number of tokens in the vocabulary
legacy_stoi = legacy_vocab.stoi
print("The index of 'example' is", legacy_stoi['example']) # word -> index
legacy_itos = legacy_vocab.itos
print("The token at index 686 is", legacy_itos[686]) # index -> word

# Set up the min_freq value in the Vocab class
TEXT.build_vocab(legacy_train, min_freq=10) # cutoff frequency: drop words that occur fewer than 10 times
legacy_vocab2 = TEXT.vocab
print("The length of the legacy vocab is", len(legacy_vocab2))

# Step 3: Generate batch iterator
"""
The legacy Iterator class is used to batch the dataset and send to the target device, like CPU or GPU.
"""
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL)  # datasets here refers to torchtext.legacy.datasets
legacy_train_iterator, legacy_test_iterator = data.Iterator.splits(
    (legacy_train, legacy_test), batch_size=8, device = device)
"""
For a NLP workflow, it's also common to define an iterator and batch texts with similar lengths together. The legacy BucketIterator class in torchtext library minimizes the amount of padding needed.
"""
from torchtext.legacy.data import BucketIterator
legacy_train, legacy_test = datasets.IMDB.splits(TEXT, LABEL)
legacy_train_bucketiterator, legacy_test_bucketiterator = data.BucketIterator.splits(
    (legacy_train, legacy_test),
    sort_key=lambda x: len(x.text),
    batch_size=8, device = device)

# Step 4: Iterate batch to train a model
"""
The legacy batch iterator can be iterated or executed with next() method.
"""
next(iter(legacy_train_iterator))

3. The New API

The new torchtext API is the focus of this post. Below are the complete official usage code, detailed explanations, and the results of running it.

3.1 Step 1: Create a dataset object

The new dataset API returns the train/test dataset split directly without the preprocessing information. Each split is an iterator which yields the raw texts and labels line-by-line.

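The new API no longer uses Field when creating a dataset. It returns the train/test splits directly, so loading the data is separated from processing it.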

from torchtext.datasets import IMDB
train_data, test_data = IMDB(split=('train', 'test'))
train_iter, test_iter = iter(train_data), iter(test_data)
# To print out the raw data, you can call the next() function on the IterableDataset.
print(next(train_iter))

Running the code above prints a tuple whose first element is the label and whose second element is the text.

(1, 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, even then it\'s not shot like some cheaply made porno. While my countrymen mind find it shocking, in reality sex and nudity are a major staple in Swedish cinema. Even Ingmar Bergman, arguably their answer to good old boy John Ford, had sex scenes in his films.<br /><br />I do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in America. I AM CURIOUS-YELLOW is a good film for anyone wanting to study the meat and potatoes (no pun intended) of Swedish cinema. But really, this film doesn\'t have much of a plot.')

Alternatively, we can define our own dataset (this should look familiar by now):

from torch.utils.data import Dataset, DataLoader

# the comma matters: without it Python joins the two literals into one string
sentences = ["I am happy",
             "I am angry"]
labels = [0, 1]


class MyDataset(Dataset):
    def __init__(self, text, label):
        self.text = text
        self.label = label

    def __len__(self):
        return len(self.label)

    def __getitem__(self, item):
        return self.text[item], self.label[item]

data_iter = iter(MyDataset(text=sentences, label=labels))
print(next(data_iter))
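
A note on why iter() works on the class above: as far as I can tell, the Dataset base class does not define __iter__, so Python falls back to calling __getitem__ with 0, 1, 2, ... until an IndexError is raised. The sketch below shows the same behaviour with a plain class and no torch at all; PairDataset is a made-up name used only for this illustration.

```python
class PairDataset:
    """Plain-Python stand-in for the Dataset subclass above (no torch needed)."""

    def __init__(self, texts, labels):
        assert len(texts) == len(labels), "one label per text"
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # Raises IndexError past the end, which is what stops iteration.
        return self.texts[index], self.labels[index]


pairs = PairDataset(["I am happy", "I am angry"], [0, 1])
print(len(pairs))
print(list(pairs))
```

The length check in __init__ would also have caught a text list and a label list of different sizes.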

3.2 Step 2: Build the data processing pipeline

Users have direct access to different kinds of tokenizers via the get_tokenizer() function.

Besides calling get_tokenizer(), we can also define a tokenizer of our own.

# Option 1
from torchtext.data.utils import get_tokenizer
tokenizer = get_tokenizer('basic_english')

# Option 2: define your own tokenizer
import spacy
import re

class tokenize(object):
    
    def __init__(self, lang):
        self.nlp = spacy.load(lang)
            
    def tokenizer(self, sentence):
        sentence = re.sub(
        r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", str(sentence))
        sentence = re.sub(r"[ ]+", " ", sentence)
        sentence = re.sub(r"\!+", "!", sentence)
        sentence = re.sub(r"\,+", ",", sentence)
        sentence = re.sub(r"\?+", "?", sentence)
        sentence = sentence.lower()
        return [tok.text for tok in self.nlp.tokenizer(sentence) if tok.text != " "]
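
If installing spaCy is not an option, a regex-only tokenizer can stand in while experimenting. The following is a minimal sketch under my own assumptions: simple_tokenizer and its two patterns are hypothetical, and the output will not match basic_english or spaCy token for token.

```python
import re

# Hypothetical helper, not part of torchtext or spaCy.
_BR_TAG = re.compile(r"<br\s*/?>")                      # HTML line breaks found in IMDB reviews
_TOKEN = re.compile(r"[a-z0-9]+(?:'[a-z]+)?|[.,!?]")    # words (with optional 's, n't, ...) or punctuation


def simple_tokenizer(text):
    """Lowercase, replace <br /> tags with spaces, split into word and punctuation tokens."""
    text = _BR_TAG.sub(" ", text.lower())
    return _TOKEN.findall(text)


print(simple_tokenizer("It's GREAT!<br /><br />Really, watch it."))
```

Any callable that maps a string to a list of tokens can be used wherever tokenizer appears in the rest of this post.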

To have more flexibility, users can build the vocabulary directly with the Vocab class. For example, the argument min_freq sets the cutoff frequency for the vocabulary. Special tokens such as <BOS> and <EOS> can be assigned through the specials argument.

In the code below this is done with the vocab() factory function. The tokens listed in specials are placed first, so they receive the indices 0, 1, 2, ... in order. We can also call set_default_index() to choose the index returned for words that are not in the vocabulary.

from collections import Counter, OrderedDict
from torchtext.vocab import vocab # note: this imports the vocab() factory function, not the Vocab class

train_data = IMDB(split='train')
train_iter = iter(train_data)
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
ordered_dict = OrderedDict(sorted_by_freq_tuples)
# the name vocab is rebound here from the factory function to the resulting Vocab object
vocab = vocab(ordered_dict, min_freq=10, specials=['<unk>', '<BOS>', '<EOS>', '<PAD>'])

"""
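For example:
counter = Counter(["a", "a", "b", "b", "b"])
counter.items()

Out[19]: dict_items([('a', 2), ('b', 3)])
Counter counts how many times each word occurs.
"""

"""
I tried passing counter to vocab() directly and that also works, but I am not sure about the downstream effects.
Since vocab() keeps the order of the dict it is given, the indices would then presumably follow first-occurrence order rather than frequency order.
"""

"""
At this point we can start converting words to indices, for example
print(vocab["love"])
However, looking up a word that is not in the vocabulary raises an error.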
"""

# adding <unk> token and default index
vocab.set_default_index(-1) # after this, looking up a word that is NOT in the vocabulary returns -1
# specials already defines '<unk>' for unknown tokens, so it is better to return vocab['<unk>'] for unknown words
vocab.set_default_index(vocab['<unk>'])

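Besides the approach above, the torchtext.vocab module also provides build_vocab_from_iterator(). Its source essentially wraps the Counter(), OrderedDict(), and vocab() steps, so the code above is roughly equivalent to: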

from torchtext.vocab import build_vocab_from_iterator

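def yield_tokens(data_iter):
    for (label, line) in data_iter:
        yield tokenizer(line)  # same tokenizer as above, so both vocabularies are built from the same tokens

# pass a fresh IMDB iterator: train_iter above was already consumed by the Counter loop
word_vocab = build_vocab_from_iterator(yield_tokens(IMDB(split='train')), min_freq=10, specials=['<unk>', '<BOS>', '<EOS>', '<PAD>'])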

word_vocab.set_default_index(word_vocab['<unk>'])
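
To make the roles of min_freq, specials, and the default index concrete, here is a pure-Python sketch of the same pipeline on a three-sentence toy corpus. TinyVocab and build_tiny_vocab are names I made up for illustration; this is not torchtext's implementation, only my reading of the behaviour described above.

```python
from collections import Counter, OrderedDict


class TinyVocab:
    """Hypothetical stand-in for a torchtext Vocab, for illustration only."""

    def __init__(self, ordered_dict, min_freq=1, specials=()):
        # Specials come first, so they receive indices 0, 1, 2, ...
        self.itos = list(specials)
        for token, freq in ordered_dict.items():
            if freq >= min_freq and token not in specials:
                self.itos.append(token)
        self.stoi = {token: index for index, token in enumerate(self.itos)}
        self.default_index = None

    def set_default_index(self, index):
        self.default_index = index

    def __len__(self):
        return len(self.itos)

    def __getitem__(self, token):
        if token in self.stoi:
            return self.stoi[token]
        if self.default_index is None:
            # The real Vocab also raises an error here; this sketch uses KeyError.
            raise KeyError(token)
        return self.default_index


def build_tiny_vocab(token_lists, min_freq=1, specials=()):
    counter = Counter()
    for tokens in token_lists:
        counter.update(tokens)
    # Descending frequency; sorted() is stable, so ties keep first-seen order.
    ordered = OrderedDict(sorted(counter.items(), key=lambda kv: kv[1], reverse=True))
    return TinyVocab(ordered, min_freq=min_freq, specials=specials)


corpus = [["i", "am", "happy"], ["i", "am", "angry"], ["i", "am", "here"]]
tiny = build_tiny_vocab(corpus, min_freq=2, specials=["<unk>", "<BOS>", "<EOS>", "<PAD>"])
tiny.set_default_index(tiny["<unk>"])
print(len(tiny), tiny["i"], tiny["am"], tiny["happy"])  # 6 4 5 0
```

Only "i" and "am" reach min_freq=2, so the vocabulary holds the four specials plus two words, and "happy" falls back to the index of <unk>.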

Both text_transform and label_transform are callable objects, such as the lambda functions here, that process the raw text and label data from the dataset iterators. Users can add the special symbols <BOS> and <EOS> to the sentence in text_transform.

With the vocabulary built, the two transforms can be defined as follows:

text_transform = lambda x: [vocab['<BOS>']] + [vocab[token] for token in tokenizer(x)] + [vocab['<EOS>']]
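# older releases yield 'neg'/'pos'; newer ones yield integer labels such as the 1 printed earlier (assumed here: 1 = neg, 2 = pos)
label_transform = lambda x: 1 if x in ('pos', 2) else 0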

# Print out the output of text_transform
print("input to the text_transform:", "here is an example")
print("output of the text_transform:", text_transform("here is an example"))

"""
out:
input to the text_transform: here is an example
output of the text_transform: [1, 227, 9, 35, 711, 2]
"""

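The same transform can be tried without downloading IMDB. The sketch below uses a toy dictionary whose indices I made up (toy_stoi, toy_lookup, and toy_text_transform are hypothetical names), with a whitespace split standing in for the tokenizer.

```python
# Toy vocabulary, made up for this example; real indices come from the vocab built on IMDB.
toy_stoi = {"<unk>": 0, "<BOS>": 1, "<EOS>": 2, "<PAD>": 3,
            "here": 4, "is": 5, "an": 6, "example": 7}


def toy_lookup(token):
    # Unknown tokens fall back to the index of <unk>, like set_default_index(vocab['<unk>']).
    return toy_stoi.get(token, toy_stoi["<unk>"])


def toy_text_transform(text):
    tokens = text.lower().split()  # whitespace split as a stand-in tokenizer
    return [toy_stoi["<BOS>"]] + [toy_lookup(t) for t in tokens] + [toy_stoi["<EOS>"]]


print(toy_text_transform("here is an example"))       # [1, 4, 5, 6, 7, 2]
print(toy_text_transform("here is another example"))  # [1, 4, 5, 0, 7, 2]
```

The second call shows the out-of-vocabulary word "another" mapped to 0, the index of <unk>.
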
3.3 Step 3: Generate batch iterator

torch.utils.data.DataLoader is used to generate data batch. Users could customize the data batch by defining a function with the collate_fn argument in the DataLoader. Here, in the collate_batch func, we process the raw text data and add padding to dynamically match the longest sentence in a batch.

The new API works together with the DataLoader interface; we need to define a collate_batch() function to process each batch.

import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    """
    Pad every sentence in the batch to the same length; padding_value is the index of '<PAD>' (vocab['<PAD>'] = 3).
    """
    label_list, text_list = [], []
    for (_label, _text) in batch:
        label_list.append(label_transform(_label))
        processed_text = torch.tensor(text_transform(_text))
        text_list.append(processed_text)
    # with the default batch_first=False the padded tensor has shape (longest length, batch size)
    return torch.tensor(label_list), pad_sequence(text_list, padding_value=3.0)

train_iter = IMDB(split='train')
train_dataloader = DataLoader(list(train_iter), batch_size=8, shuffle=True, 
                              collate_fn=collate_batch)
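
What the padding step does can be sketched with plain lists. pad_batch and to_time_major below are hypothetical helpers, not torch functions: the first pads every sequence to the longest one in the batch, and the second transposes the result into the (time, batch) layout that pad_sequence returns when batch_first is left at its default of False.

```python
def pad_batch(sequences, pad_id):
    """Right-pad every sequence with pad_id up to the longest length in the batch."""
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_id] * (max_len - len(seq)) for seq in sequences]


def to_time_major(padded):
    """Transpose (batch, time) rows into (time, batch) rows."""
    return [list(step) for step in zip(*padded)]


# Three already-numericalized sentences; 1 = <BOS>, 2 = <EOS>, 3 = <PAD>.
batch = [[1, 4, 5, 2], [1, 7, 2], [1, 6, 8, 9, 2]]
padded = pad_batch(batch, pad_id=3)
print(padded)                 # [[1, 4, 5, 2, 3], [1, 7, 2, 3, 3], [1, 6, 8, 9, 2]]
print(to_time_major(padded))  # [[1, 1, 1], [4, 7, 6], [5, 2, 8], [2, 3, 9], [3, 3, 2]]
```

If your model expects batch-first input, passing batch_first=True to pad_sequence gives the first layout instead.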

To group texts of similar length together, as the legacy BucketIterator class did, we first randomly create multiple "pools", each of size batch_size * 100. Then we sort the samples within each pool by length. This idea can be implemented succinctly through the batch_sampler argument of PyTorch's DataLoader. batch_sampler accepts a Sampler or an iterable object that yields the indices of the next batch. In the code below, we implement a generator that yields batches of indices whose corresponding samples have similar lengths.

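In other words, we draw pools of batch_size * 100 samples, sort each pool by sentence length so that sentences of similar length end up together, and implement this with a batch_sampler function.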

import random

train_iter = IMDB(split='train')
train_list = list(train_iter)
batch_size = 8  # A batch size of 8

def batch_sampler():
    indices = [(i, len(tokenizer(s[1]))) for i, s in enumerate(train_list)]
    random.shuffle(indices)
    pooled_indices = []
    # create pool of indices with similar lengths 
    for i in range(0, len(indices), batch_size * 100):
        pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))

    pooled_indices = [x[0] for x in pooled_indices]

    # yield indices for current batch
    for i in range(0, len(pooled_indices), batch_size):
        yield pooled_indices[i:i + batch_size]

bucket_dataloader = DataLoader(train_list, batch_sampler=batch_sampler(),
                               collate_fn=collate_batch)

print(next(iter(bucket_dataloader)))
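
The pooling logic is easier to inspect on a handful of made-up lengths. The sketch below reimplements the same idea without torch or IMDB; pooled_batches, pool_factor, and seed are names and parameters I introduced for the illustration.

```python
import random


def pooled_batches(lengths, batch_size, pool_factor=100, seed=None):
    """Yield lists of indices; within each pool, indices are ordered by length."""
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)
    pool_size = batch_size * pool_factor
    ordered = []
    for start in range(0, len(indices), pool_size):
        pool = indices[start:start + pool_size]
        ordered.extend(sorted(pool, key=lambda i: lengths[i]))
    for start in range(0, len(ordered), batch_size):
        yield ordered[start:start + batch_size]


# Made-up sentence lengths for ten samples; all ten fit in a single pool.
lengths = [12, 3, 45, 7, 8, 30, 2, 19, 5, 27]
batches = list(pooled_batches(lengths, batch_size=2, pool_factor=100, seed=0))
for batch in batches:
    print(batch, [lengths[i] for i in batch])
```

Each printed batch pairs samples of neighbouring lengths, so little padding is needed. One caveat that also applies to the code above: a generator is used up after one pass, so a new one has to be created for every epoch.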

4. References

If your English is good and you have the time, be sure to read the two official documents below.

Comparison of the legacy and new APIs

torchtext 0.15.0
