NLP (12): Computing Text Similarity with word2vec + Siamese BiLSTM

1. Model: my_bilstm.py

import torch
from torch import nn

class SiameseLSTM(nn.Module):
    def __init__(self, input_size):
        super(SiameseLSTM, self).__init__()
        # Shared bidirectional LSTM: both sentences are encoded with the same weights.
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=64, num_layers=1,
                            batch_first=True, bidirectional=True)
        # 64 hidden units * 2 directions * 2 sentences = 256 features into the classifier.
        self.fc = nn.Sequential(
            nn.Linear(256, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 1),
        )

    def forward(self, data1, data2):
        out1, (h1, c1) = self.lstm(data1)
        out2, (h2, c2) = self.lstm(data2)
        # Take the output at the last time step of each sequence: (batch, 128).
        pre1 = out1[:, -1, :]
        pre2 = out2[:, -1, :]
        # Concatenate the two sentence vectors and map them to a single logit.
        pre = torch.cat([pre1, pre2], dim=1)
        out = self.fc(pre)
        return out

if __name__ == '__main__':
    d1 = torch.rand(2, 16, 128)  # (batch, seq_len, embedding_dim)
    d2 = torch.rand(2, 16, 128)
    model = SiameseLSTM(128)
    print(model(d1, d2).shape)  # torch.Size([2, 1])
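
One caveat with out[:, -1, :]: in a bidirectional LSTM the backward direction's output at the last time step has only read a single token, and with right zero-padding the last position may be padding rather than a real character. An alternative, shown only as a sketch under the same single-layer, 64-unit setup, is to concatenate the final hidden states h_n of both directions; the model above does not do this.

import torch
from torch import nn

lstm = nn.LSTM(input_size=128, hidden_size=64, num_layers=1, batch_first=True, bidirectional=True)
x = torch.rand(2, 16, 128)
out, (h_n, c_n) = lstm(x)                      # h_n: (num_layers * num_directions, batch, 64)
sent_vec = torch.cat([h_n[0], h_n[1]], dim=1)  # forward + backward final states: (batch, 128)
print(sent_vec.shape)                          # torch.Size([2, 128])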

2. Dataset: my_dataset.py

import torch.utils.data as data


class MyDataset(data.Dataset):
    """Keeps the raw sentence pairs and labels; embedding is done later, per batch."""

    def __init__(self, texta, textb, label):
        self.texta = texta
        self.textb = textb
        self.label = label

    def __getitem__(self, item):
        texta = self.texta[item]
        textb = self.textb[item]
        label = self.label[item]
        return texta, textb, label

    def __len__(self):
        return len(self.texta)
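
Because MyDataset returns the raw strings, the default DataLoader collate simply batches them as sequences of strings, which is why the character embedding step happens inside the training loop rather than here. A quick check (the sentence pairs below are placeholders):

import torch
from torch.utils.data import DataLoader
from my_dataset import MyDataset

ds = MyDataset(["浙江杭州富阳区", "银湖街道新常村"],
               ["浙江富阳区银湖街", "富阳区银湖街道"],
               torch.LongTensor([1, 0]))
loader = DataLoader(ds, batch_size=2)
texta, textb, label = next(iter(loader))
print(texta)   # a batch of raw strings: ['浙江杭州富阳区', '银湖街道新常村']
print(label)   # tensor([1, 0])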

3. Word embedding: my_word2vec.py

from gensim.models.fasttext import FastText
import torch
import numpy as np
import os

class WordEmbedding(object):
    def __init__(self):
        parent_path = os.path.split(os.path.realpath(__file__))[0]
        self.root = parent_path[:parent_path.find("models")]  # project root
        self.word_fasttext = os.path.join(self.root, "checkpoints", "word2vec", "word_fasttext.model")
        self.char_fasttext = os.path.join(self.root, "checkpoints", "word2vec", "char_fasttext.model")
        # Character-level FastText model: each Chinese character maps to a 128-dim vector.
        self.model = FastText.load(self.char_fasttext)

    def sentenceTupleToEmbedding(self, data1, data2):
        # Pad both batches to the length of the longest sentence (in characters).
        aCutListMaxLen = max([len(list(str(sentence_a))) for sentence_a in data1])
        bCutListMaxLen = max([len(list(str(sentence_b))) for sentence_b in data2])
        seq_len = max(aCutListMaxLen, bCutListMaxLen)
        a = self.sqence_vec(data1, seq_len)  # (batch_size, seq_len, embedding_dim)
        b = self.sqence_vec(data2, seq_len)
        return torch.FloatTensor(a), torch.FloatTensor(b)

    def sqence_vec(self, data, seq_len):
        data_a_vec = []
        for sequence_a in data:
            sequence_vec = []  # seq_len * 128
            for word_a in list(str(sequence_a)):
                if word_a in self.model.wv:
                    sequence_vec.append(self.model.wv[word_a])
            # reshape so the zero-padding below also works for empty sequences
            sequence_vec = np.array(sequence_vec).reshape(-1, 128)
            add = np.zeros((seq_len - sequence_vec.shape[0], 128))
            sequenceVec = np.vstack((sequence_vec, add))
            data_a_vec.append(sequenceVec)
        a_vec = np.array(data_a_vec)
        return a_vec

if __name__ == '__main__':
    word = WordEmbedding()
    data1 = ("浙江杭州富阳区银湖街黄先生的外卖", "浙江杭州富阳区银湖街黄先生的外卖")
    data2 = ("富阳区浙江富阳区银湖街道新常村", "浙江杭州富阳区银湖街黄先生的外卖")
    a, b = word.sentenceTupleToEmbedding(data1, data2)
    print(a.shape)  # torch.Size([2, 16, 128])
    print(b)
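
The char_fasttext.model checkpoint loaded above is a gensim FastText model over character-level tokens; its training script is not part of this post. A minimal sketch of how such a file could be produced (assuming gensim 4.x, 128-dimensional vectors to match the LSTM input_size, and a small placeholder corpus):

from gensim.models.fasttext import FastText

corpus = ["浙江杭州富阳区银湖街黄先生的外卖", "富阳区浙江富阳区银湖街道新常村"]  # placeholder sentences
tokenized = [list(s) for s in corpus]  # character-level tokenization

model = FastText(vector_size=128, window=5, min_count=1)
model.build_vocab(corpus_iterable=tokenized)
model.train(corpus_iterable=tokenized, total_examples=len(tokenized), epochs=10)
model.save("char_fasttext.model")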

4. Runner: run__bilstm.py

import torch
import os
from torch.utils.data import DataLoader
from my_dataset import MyDataset
import pandas as pd
import numpy as np
from my_bilstm import SiameseLSTM
import torch.nn as nn
from my_word2vec import WordEmbedding


class RunBiLSTM():
    def __init__(self):
        self.learning_rate = 0.001
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        parent_path = os.path.split(os.path.realpath(__file__))[0]
        self.root = parent_path[:parent_path.find("models")]  # project root
        self.train_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "train.csv")
        self.val_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "val.csv")
        self.test_path = os.path.join(self.root, "datas", "bert_data", "sim_data", "test.csv")
        self.batch_size = 256
        self.epoch = 50
        self.criterion = nn.BCEWithLogitsLoss().to(self.device)
        self.word = WordEmbedding()
        self.check_point = os.path.join(self.root, "checkpoints", "char_bilstm", "char_bilstm.pth")

    def get_loader(self, path):
        # tab-separated file with columns s1, s2 (sentences) and y (0/1 label)
        data = pd.read_csv(path, sep="\t")
        d1, d2, y = data["s1"], data["s2"], list(data["y"])

        dataset = MyDataset(d1, d2, torch.LongTensor(y))
        data_iter = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
        return data_iter

    def binary_acc(self, preds, y):
        # round sigmoid(logits) to 0/1 and compare against the labels
        preds = torch.round(torch.sigmoid(preds))
        correct = torch.eq(preds, y).float()
        acc = correct.sum() / len(correct)
        return acc

    def train(self, mynet, train_iter, optimizer, criterion, epoch, device):
        avg_acc = []
        avg_loss = []
        mynet.train()
        for batch_id, (data1, data2, label) in enumerate(train_iter):
            try:
                a, b = self.word.sentenceTupleToEmbedding(data1, data2)
            except Exception as e:
                print("embedding failed, skipping batch:", e)
                continue
            a, b, label = a.to(device), b.to(device), label.to(device)
            logits = mynet(a, b).squeeze(1)

            loss = criterion(logits, label.float())
            acc = self.binary_acc(logits, label.float()).item()
            avg_acc.append(acc)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch_id % 100 == 0:
                print("epoch:", epoch, "batch: ", batch_id, "train loss:", loss.item(), "acc:", acc)
            avg_loss.append(loss.item())
        avg_acc = np.array(avg_acc).mean()
        avg_loss = np.array(avg_loss).mean()
        print('train acc:', avg_acc)
        print("train loss", avg_loss)

    def eval(self, mynet, test_iter, criterion, epoch, device):
        mynet.eval()
        avg_acc = []
        avg_loss = []
        with torch.no_grad():
            for batch_id, (data1, data2, label) in enumerate(test_iter):
                try:
                    a, b = self.word.sentenceTupleToEmbedding(data1, data2)
                except Exception as e:
                    continue

                a, b, label = a.to(device), b.to(device), label.to(device)
                logits = mynet(a, b).squeeze(1)
                loss = criterion(logits, label.float())
                acc = self.binary_acc(logits, label.float()).item()
                avg_acc.append(acc)
                avg_loss.append(loss.item())
                if batch_id > 50:
                    # validate on roughly 50 batches only, to keep evaluation fast
                    break
        avg_acc = np.array(avg_acc).mean()
        avg_loss = np.array(avg_loss).mean()
        print('>>test acc:', avg_acc)
        print(">>test loss:", avg_loss)
        return (avg_acc, avg_loss)

    def run_train(self):
        model = SiameseLSTM(128).to(self.device)
        max_acc = 0
        train_iter = self.get_loader(self.train_path)
        val_iter = self.get_loader(self.val_path)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)

        for epoch in range(self.epoch):
            self.train(model, train_iter, optimizer, self.criterion, epoch, self.device)
            eval_acc, eval_loss = self.eval(model, val_iter, self.criterion, epoch, self.device)
            if eval_acc > max_acc:
                print("save model")
                torch.save(model.state_dict(), self.check_point)
                max_acc = eval_acc

    def test(self):
        da = self.get_loader(self.val_path)
        for batch_id, (data1, data2, label) in enumerate(da):
            print(label)
            break

if __name__ == '__main__':
    RunBiLSTM().run_train()
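
For inference after training, a small standalone helper could load the saved checkpoint and score a single sentence pair; the script below is only a sketch (the predict function and the checkpoint path are ours, not part of the original project):

import torch
from my_bilstm import SiameseLSTM
from my_word2vec import WordEmbedding

def predict(sentence_a, sentence_b, check_point="char_bilstm.pth", device="cpu"):
    word = WordEmbedding()
    model = SiameseLSTM(128).to(device)
    model.load_state_dict(torch.load(check_point, map_location=device))
    model.eval()
    a, b = word.sentenceTupleToEmbedding((sentence_a,), (sentence_b,))
    with torch.no_grad():
        logit = model(a.to(device), b.to(device)).squeeze(1)
    return torch.sigmoid(logit).item()  # similarity score in [0, 1]

if __name__ == '__main__':
    print(predict("浙江杭州富阳区银湖街黄先生的外卖", "富阳区浙江富阳区银湖街道新常村"))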

5. Experimental results

epoch: 32 batch:  0 train loss: 0.1690833866596222 acc: 0.91796875
epoch: 32 batch:  100 train loss: 0.16252592206001282 acc: 0.9296875
epoch: 32 batch:  200 train loss: 0.16619177162647247 acc: 0.9375
epoch: 32 batch:  300 train loss: 0.1599806845188141 acc: 0.9453125
train acc: 0.9276657348242812
train loss 0.18327004048294915
>>test acc: 0.9079337269067764
>>test loss: 0.24136937782168388

train acc: 0.9688872803514377
train loss 0.08603085891697734
>>test acc: 0.9298221915960312
>>test loss: 0.22169270366430283