
1. Preparing the Data

seq = "I love you. Chinese vocabulary is generally used to express one's feelings to another person whom one admires. It can also be used among relatives. It is the expression of one person's feelings to another. It can also be used to express things with strong feelings, such as pets and goods. It can be said by boys to girls, girls to boys, girls to girls, boys to boys."

Next, convert the text to lowercase, strip punctuation while keeping spaces, and build letter-index lookup tables, as shown below:

index2word = {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f', 6: 'g', 7: 'h', 8: 'i', 9: 'j', 10: 'k', 11: 'l', 12: 'm', 13: 'n', 14: 'o', 15: 'p', 16: 'q', 17: 'r', 18: 's', 19: 't', 20: 'u', 21: 'v', 22: 'w', 23: 'x', 24: 'y', 25: 'z', 26: ' '}
word2index = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, ' ': 26}

Then represent each letter in seq by its index, for example:

"i love" = ['i', ' ', 'l', 'o', 'v', 'e'] = [8, 26, 11, 14, 21, 4]

Finally, choose a window size, i.e. how many letters are used to predict the next one. The code below uses window = 3, so every 3 letters predict the letter that follows, as in the sketch below:
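
For instance, with window = 3 the string "i love" yields three training pairs (a minimal standalone sketch of the same windowing logic used in the full code below):

letters = 'abcdefghijklmnopqrstuvwxyz '
word2index = {ch: i for i, ch in enumerate(letters)}
encoded = [word2index[c] for c in "i love"]
window = 3
for i in range(len(encoded) - window):
    # each window of 3 indices predicts the index that follows it
    print(encoded[i:i + window], '->', encoded[i + window])
# [8, 26, 11] -> 14
# [26, 11, 14] -> 21
# [11, 14, 21] -> 4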

2. The Model

  • The input is fed through an embedding layer to produce a vector for each letter
  • The model stacks a bidirectional LSTM and a unidirectional LSTM; the LSTM's output at the last time step is the input to a fully connected layer
  • Overall flow: Embedding -> BiLSTM -> LSTM -> Linear, as sketched below:
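
Before the full listing, here is a standalone shape walk-through of this stack (sizes match the code below: vocab 27, embedding 16, hidden 32, batch 10, window 3):

import torch
import torch.nn as nn

emb = nn.Embedding(27, 16)
bilstm = nn.LSTM(16, 32, 1, bidirectional=True)  # output: 2 * 32 = 64 features
lstm = nn.LSTM(64, 64, 1)
fc = nn.Linear(64, 27)

x = torch.randint(0, 27, (10, 3))   # (batch, seq) of letter indices
h = emb(x).permute(1, 0, 2)         # (seq, batch, 16), nn.LSTM's default layout
h, _ = bilstm(h)                    # (seq, batch, 64), forward/backward concatenated
h, _ = lstm(h)                      # (seq, batch, 64)
out = fc(h[-1])                     # (batch, 27), last time step only
print(out.shape)                    # torch.Size([10, 27])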

3. Full Code

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data


seq = "I love you. Chinese vocabulary is generally used to express one's feelings to another person whom one admires. It can also be used among relatives. It is the expression of one person's feelings to another. It can also be used to express things with strong feelings, such as pets and goods. It can be said by boys to girls, girls to boys, girls to girls, boys to boys."
letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ']
# Convert to lowercase and keep only letters and spaces (strips punctuation)
seq_lower = [i for i in seq.lower() if i in letters]

# Build letter <-> index lookup tables
word2index = {ch: i for i, ch in enumerate(letters)}
index2word = {i: ch for ch, i in word2index.items()}

# Encode the sentence as a list of indices
seq_index = [word2index[i] for i in seq_lower]
seq_length = len(seq_index)
window = 3
# Build sliding-window training pairs
batch_x = []
batch_y = []
for i in range(seq_length - window + 1):
    x = seq_index[i: i + window]
    if i + window >= seq_length:
        # the final window has no following letter; use a space as its target
        y = word2index[' ']
    else:
        y = seq_index[i + window]
    batch_x.append(x)
    batch_y.append(y)

# Training tensors
batch_x, batch_y = torch.LongTensor(batch_x), torch.LongTensor(batch_y)

# Hyperparameters
vocab_size = len(letters)
embedding_size = 16
n_hidden = 32
batch_size = 10
num_classes = vocab_size

dataset = Data.TensorDataset(batch_x, batch_y)
loader = Data.DataLoader(dataset, batch_size, shuffle=True)

# Define the model
class BiLSTM(nn.Module):
    def __init__(self):
        super(BiLSTM, self).__init__()
        self.word_vec = nn.Embedding(vocab_size, embedding_size)
        # bidirectional LSTM: concatenates forward and backward states (2 * n_hidden features)
        self.bilstm = nn.LSTM(embedding_size, n_hidden, 1, bidirectional=True)
        self.lstm = nn.LSTM(2 * n_hidden, 2 * n_hidden, 1, bidirectional=False)
        self.fc = nn.Linear(n_hidden * 2, num_classes)

    def forward(self, input):
        embedding_input = self.word_vec(input)
        # swap batch and seq dims: (batch, seq, emb) -> (seq, batch, emb), nn.LSTM's default layout
        embedding_input = embedding_input.permute(1, 0, 2)
        bilstm_output, (h_n1, c_n1) = self.bilstm(embedding_input)
        lstm_output, (h_n2, c_n2) = self.lstm(bilstm_output)
        # classify from the output at the last time step
        fc_out = self.fc(lstm_output[-1])
        return fc_out

model = BiLSTM()
print(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(50):
    cost = 0
    for input_batch, target_batch in loader:
        pred = model(input_batch)
        loss = criterion(pred, target_batch)
        cost += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Epoch: %d,  loss: %.5f " % (epoch, cost))

# Inference
test_text = 'lov'
test_batch = torch.LongTensor([[word2index[i] for i in test_text]])
model.eval()
with torch.no_grad():
    out = model(test_batch)
predict = torch.max(out, 1)[1].item()
print(test_text, "is followed by:", index2word[predict])
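
The trained model can also generate a longer continuation by feeding its own greedy prediction back in, one letter at a time. A hedged sketch (generate is a hypothetical helper, not part of the original code; model, word2index, index2word, and window come from the code above):

# Hypothetical helper: greedy letter-by-letter generation
def generate(prefix, steps=10):
    # prefix must be at least window letters long
    text = prefix
    model.eval()
    with torch.no_grad():
        for _ in range(steps):
            x = torch.LongTensor([[word2index[c] for c in text[-window:]]])
            text += index2word[torch.max(model(x), 1)[1].item()]
    return text

print(generate('lov'))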

 
