当前位置: 首页 > news >正文

【Pytorch】语言模型上的动态量化

目录

■导言

①定义模型

②加载文本数据

③加载预训练模型

④测试动态量化

■结论



■导言

量化涉及将模型的权重和激活从float转换为int,这可以导致更小的模型大小和更快的推理,并且只对准确性造成很小的影响。

本文将把最简单的量化形式-动态量化-应用于基于lstm的下一个单词预测模型,与PyTorch示例中的单词语言模型非常相似。

# importsimport osfrom io import openimport timeimport torchimport torch.nn as nnimport torch.nn.functional as F

①定义模型

定义LSTM模型体系结构,遵循单词语言模型示例中的模型。

# 定义一个包含编码器、循环层和解码器的LSTM模型
class LSTMModel(nn.Module):"""容器模块,包含编码器、递归模块(LSTM)和解码器。"""def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):"""初始化LSTM模型。参数:ntoken (int): 词汇表大小,即输入数据的类别数。ninp (int): 输入层维度,即词嵌入的维度。nhid (int): 隐藏层维度,即LSTM中每个单元的隐藏状态维度。nlayers (int): LSTM的层数。dropout (float): Dropout率,用于防止过拟合,默认为0.5。"""super(LSTMModel, self).__init__()# Dropout层,用于在训练过程中随机丢弃一部分神经元,防止过拟合self.drop = nn.Dropout(dropout)# 编码器层:将输入的离散词索引转换为密集向量表示(词嵌入)self.encoder = nn.Embedding(ntoken, ninp)# LSTM层:负责处理序列数据,提取时序特征self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)# 解码器层:将LSTM输出的隐藏状态映射回词汇表空间,预测下一个词self.decoder = nn.Linear(nhid, ntoken)# 初始化模型参数self.init_weights()# 保存隐藏层维度和网络层数供后续使用self.nhid = nhidself.nlayers = nlayersdef init_weights(self):"""初始化模型参数,采用均匀分布进行初始化"""initrange = 0.1# 对编码器的权重进行初始化self.encoder.weight.data.uniform_(-initrange, initrange)# 解码器偏置初始化为0self.decoder.bias.data.zero_()# 解码器权重也使用相同的均匀分布初始化self.decoder.weight.data.uniform_(-initrange, initrange)def forward(self, input, hidden):"""前向传播函数。参数:input (Tensor): 当前批次的输入数据,形状为(seq_len, batch_size)。hidden (tuple): 初始隐藏状态(h0, c0)。返回:decoded (Tensor): 输出结果,表示每个时间步的预测词概率。hidden (tuple): 更新后的隐藏状态。"""# 将输入通过编码器转化为嵌入向量,并应用Dropoutemb = self.drop(self.encoder(input))# 将嵌入后的数据传入LSTM进行处理,得到输出和新的隐藏状态output, hidden = self.rnn(emb, hidden)# 对LSTM的输出应用Dropoutoutput = self.drop(output)# 通过解码器将输出映射到词汇表空间,得到每个词的概率分布decoded = self.decoder(output)return decoded, hiddendef init_hidden(self, bsz):"""初始化隐藏状态(h0 和 c0),通常在每轮开始时调用。参数:bsz (int): batch size,当前批次的大小。返回:tuple: 包含初始隐藏状态的两个张量(h0, c0),形状均为(nlayers, bsz, nhid)"""weight = next(self.parameters())  # 获取第一个参数作为参考,创建相同类型的张量# 初始化为全零张量return (weight.new_zeros(self.nlayers, bsz, self.nhid),weight.new_zeros(self.nlayers, bsz, self.nhid))

这段代码定义了一个基于 LSTM 的语言模型,主要包括以下几个部分:

  1. 编码器:将输入的词索引转换为词向量(Embedding)。
  2. LSTM 层:处理序列数据,提取时序信息。
  3. 解码器:将 LSTM 的输出映射回词汇表空间,用于预测下一个词。
  4. Dropout:防止训练过程中的过拟合。
  5. 参数初始化:对 Embedding 和 Linear 层的参数进行均匀分布初始化。
  6. 隐藏状态管理:提供初始化隐藏状态的方法,便于模型在每次新序列开始时重置记忆。

②加载文本数据

接下来,加载 Wikitext-2 数据集进入Corpus, 再次跟随单词语言模型预处理示例。

# 定义一个词典类,用于建立词语与索引之间的映射
class Dictionary(object):def __init__(self):# word2idx: 词语到索引的映射字典self.word2idx = {}# idx2word: 索引到词语的列表(用于反查)self.idx2word = []def add_word(self, word):"""向词典中添加一个词。如果该词尚未存在,则将其加入列表和字典;否则直接返回已有的索引。参数:word (str): 要添加的词语返回:int: 该词对应的索引"""if word not in self.word2idx:self.idx2word.append(word)self.word2idx[word] = len(self.idx2word) - 1return self.word2idx[word]def __len__(self):"""返回词典中不同词语的数量"""return len(self.idx2word)# 定义语料库类,用于加载和处理文本数据集
class Corpus(object):def __init__(self, path):"""初始化语料库对象,并加载训练、验证和测试数据。参数:path (str): 数据集存放路径"""self.dictionary = Dictionary()# 加载并分词训练集、验证集和测试集self.train = self.tokenize(os.path.join(path, 'train.txt'))self.valid = self.tokenize(os.path.join(path, 'valid.txt'))self.test = self.tokenize(os.path.join(path, 'test.txt'))def tokenize(self, path):"""对指定路径的文本文件进行分词处理。步骤:1. 遍历文件内容,将所有词语加入词典;2. 再次遍历文件,将每句话转换为对应的索引张量;3. 将所有句子的索引拼接成一个大的张量返回。参数:path (str): 文件路径返回:Tensor: 包含整个文件词汇索引的一维张量"""assert os.path.exists(path)# 第一步:读取文本并构建词典with open(path, 'r', encoding="utf8") as f:for line in f:words = line.split() + ['<eos>']  # 每句话以 '<eos>' 结尾表示结束for word in words:self.dictionary.add_word(word)# 第二步:将文本转换为索引张量with open(path, 'r', encoding="utf8") as f:idss = []  # 存储每个句子的索引张量for line in f:words = line.split() + ['<eos>']ids = []for word in words:ids.append(self.dictionary.word2idx[word])idss.append(torch.tensor(ids).type(torch.int64))  # 转换为PyTorch张量ids = torch.cat(idss)  # 将多个句子拼接为一个一维张量return ids# 设置数据集路径
model_data_filepath = 'data/'# 创建语料库对象,加载WikiText-2数据集
corpus = Corpus(model_data_filepath + 'wikitext-2')

这段代码的主要功能是:

  • Dictionary 类:构建词典,实现词语与索引之间的双向映射。
  • Corpus 类:加载并处理文本数据集,包括训练集、验证集和测试集。

使用 tokenize 方法将文本转换为词索引张量。

每句话末尾添加特殊标记 <eos> 表示句子结束。

  • 最终通过 corpus 实例加载 WikiText-2 数据集,可用于后续模型训练或评估。

③加载预训练模型

将一些预训练的权重加载到这个模型架构中。下载所需的预训练模型:

wget https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth
​

将下载的文件放在数据目录中或相应地更新 model_data_filepath。

ntokens = len(corpus.dictionary)model = LSTMModel(ntoken = ntokens,ninp = 512,nhid = 256,nlayers = 5,
)model.load_state_dict(torch.load(model_data_filepath + 'word_language_model_quantize.pth',map_location=torch.device('cpu'),weights_only=True))model.eval()
print(model)

输出:

LSTMModel((drop): Dropout(p=0.5, inplace=False)(encoder): Embedding(33278, 512)(rnn): LSTM(512, 256, num_layers=5, dropout=0.5)(decoder): Linear(in_features=256, out_features=33278, bias=True))

现在生成一些文本,以确保预训练模型工作正确。

# 初始化一个随机输入,表示起始词的索引(形状为 (1, 1),即一个词)
input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)# 初始化模型的隐藏状态,batch_size = 1
hidden = model.init_hidden(1)# 设置温度参数,用于控制生成结果的随机性:
# 温度越高,输出越随机;温度越低,输出越确定。
temperature = 1.0# 要生成的总词数
num_words = 1000# 打开文件以写入生成的文本
with open(model_data_filepath + 'out.txt', 'w') as outf:with torch.no_grad():  # 不需要计算梯度,加快推理速度for i in range(num_words):# 模型前向传播,得到当前词的输出和新的隐藏状态output, hidden = model(input_, hidden)# 对输出进行处理,得到每个词的概率分布word_weights = output.squeeze().div(temperature).exp().cpu()# 根据概率分布随机选取下一个词的索引word_idx = torch.multinomial(word_weights, 1)[0]# 将当前预测的词作为下一次生成的输入input_.fill_(word_idx)# 将索引转换为实际词语word = corpus.dictionary.idx2word[word_idx]# 写入文件,每20个词换行,否则空格分隔outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))# 每生成100词打印进度if i % 100 == 0:print('| 已生成 {}/{} 个词'.format(i, num_words))# 读取并打印生成的全部文本
with open(model_data_filepath + 'out.txt', 'r') as outf:all_output = outf.read()print(all_output)

输出:

| Generated 0/1000 words
| Generated 100/1000 words
| Generated 200/1000 words
| Generated 300/1000 words
| Generated 400/1000 words
| Generated 500/1000 words
| Generated 600/1000 words
| Generated 700/1000 words
| Generated 800/1000 words
| Generated 900/1000 words
b'.' b'David' b'<unk>' b'states' b'the' b'album' b'of' b'the' b'key' b'(' b'3' b'@.@' b'2' b'miles' b'per' b'hour' b'destructive' b'73' b'@.@' b'8'
b'm' b')' b'ten' b'years' b'with' b'edible' b'intellectual' b'instruments' b',' b'that' b'was' b'subdivided' b'into' b'an' b'star' b'Hampshires' b'.' b'1981' b',' b'Megan'
b'Room' b'campaigned' b'in' b'1956' b'in' b'Jacob' b'Lake' b'in' b'Floyd' b'of' b'Garden' b'which' b'was' b'introduced' b'by' b'Enuff' b'<unk>' b',' b'<unk>' b'of'
b'a' b'special' b'state' b'opens' b'with' b'a' b'Amusements' b'from' b'North' b'Korea' b'and' b'Temple' b'County' b'.' b'Everson' b',' b'with' b'a' b'Shanghai' b'ultimate'
b'potential' b'play' b',' b'also' b'on' b'October' b'7' b',' b'1848' b',' b'and' b'collaborated' b'up' b'an' b'main' b'(' b'three' b'species' b'hills' b')'
b'.' b'<eos>' b'A.' b'galericulata' b'is' b'a' b'popular' b'white' b'post' b'@-@' b'spored' b',' b'raising' b'in' b'the' b'fourth' b'century' b',' b'and' b'it'
b'was' b'many' b'450' b'claims' b'where' b'there' b'are' b'no' b'official' b'amounts' b'of' b'digital' b'species' b'that' b'are' b'found' b'to' b'be' b'mild' b'.'
b'<eos>' b'<eos>' b'=' b'=' b'Musical' b'records' b'=' b'=' b'<eos>' b'<eos>' b'Gods' b'Narvik' b'a' b'year' b'meets' b'to' b'be' b'found' b'for' b'Hawaii'
b',' b'and' b'the' b'common' b'starling' b'has' b'also' b'described' b'it' b'weak' b'community' b'kings' b'.' b'Continuing' b'regarding' b'<unk>' b',' b'they' b'were' b'partially'
b'deeply' b'distinguished' b'in' b'Ireland' b'.' b'The' b'range' b'of' b'Bullet' b'material' b'is' b'a' b'non' b'@-@' b'imposed' b'planet' b'each' b',' b'even' b'outright'
b'terminology' b',' b'usually' b'available' b',' b'by' b'Russia' b'.' b'The' b'spotless' b'transit' b'along' b'a' b'24' b'Bulletin' b'and' b'could' b'be' b'expected' b'.'
b'There' b'are' b'90' b'mg' b'phallic' b'method' b'<unk>' b'as' b',' b'between' b'2016' b'and' b'1913' b',' b'and' b'first' b'a' b'important' b'sensors' b'since'
b'it' b'has' b'reached' b'a' b'clade' b',' b'distinguished' b',' b'other' b'predators' b'and' b'air' b'long' b'.' b'There' b'are' b'no' b'time' b'of' b'Ceres'
b'in' b'within' b'two' b'years' b'.' b'Five' b'size' b',' b'with' b'35' b'@,@' b'Catholics' b',' b'is' b'proposed' b'to' b'do' b'so' b'unless' b'they'
b'were' b'extremes' b'to' b'ask' b'songs' b'.' b'Chapels' b',' b'until' b'January' b'1988' b',' b'when' b'Dawn' b"'s" b'gravity' b',' b'even' b'eaten' b'by'
b'another' b'shortage' b'.' b'It' b'is' b'possible' b'that' b'group' b'suspended' b',' b'because' b'of' b'the' b'Intermediate' b',' b'related' b'after' b'this' b',' b'they'
b'makes' b'as' b'two' b'as' b'they' b'affect' b'them' b',' b'merged' b'up' b'on' b'DAGs' b'.' b'In' b'Molecular' b'<unk>' b',' b'males' b'may' b'be'
b'occasionally' b'closed' b'in' b'England' b',' b'eye' b'children' b'and' b'unrelated' b'attempts' b'to' b'start' b'variable' b'and' b'coordination' b'.' b'<unk>' b',' b'which' b','
b'unlike' b'other' b'birds' b',' b'mainly' b'fresh' b',' b'<unk>' b'they' b'are' b'<unk>' b'and' b'they' b'are' b'spotted' b'in' b'trees' b'.' b'This' b'kakapo'
b'can' b'be' b'of' b'one' b'side' b'of' b'William' b'<unk>' b'to' b'produce' b'for' b'electron' b'behaviour' b'prove' b'.' b'In' b'the' b'19th' b'century' b','
b'the' b'Skye' b'mural' b'(' b'immigrants' b'has' b'the' b'highest' b'birds' b')' b',' b'to' b'look' b'round' b'or' b'<unk>' b'.' b'With' b'8' b'million'
b'tons' b'(' b'13' b'@.@' b'5' b'in' b')' b'or' b'knowledge' b',' b'their' b'mass' b'can' b'have' b'relocated' b'to' b'the' b'distances' b'of' b'peak'
b'birds' b'.' b'<eos>' b'Mycena' b'galericulata' b'reaches' b'a' b'very' b'apparent' b'amount' b'of' b'scholars' b',' b'as' b'of' b'late' b'November' b'230' b',' b'2006'
b',' b'will' b'buy' b'the' b'overlap' b',' b'especially' b'by' b'<unk>' b'Trypanosoma' b',' b'<unk>' b',' b'and' b'sur' b'Notably' b'\xe2\x80\x93' b'Africa' b'.' b'The'
b'birds' b"'" b'clutch' b'surface' b'weight' b'and' b'kakapo' b'asserts' b'that' b'they' b'are' b'serving' b'as' b'"' b'<unk>' b'"' b',' b'a' b'flock' b'that'
b'feed' b'increase' b'by' b'Call' b'daylight' b'on' b'22' b'August' b'1801' b'.' b'The' b'surviving' b'species' b'of' b'this' b'similar' b'classification' b'is' b'meant' b'to'
b'be' b'(' b'"' b'Always' b'Bird' b'"' b')' b'that' b'could' b'be' b'seen' b'as' b'the' b'country' b'for' b'their' b'sight' b'as' b'it' b'ruler'
b',' b'as' b'it' b'hit' b'that' b'it' b'is' b'probably' b'more' b'time' b'.' b'<eos>' b'Family' b'therapy' b',' b'about' b'3' b'@.@' b'5' b'million'
b'years' b',' b'suggests' b'of' b'1' b'@,@' b'000' b'to' b'eight' b'kilometres' b'(' b'19' b'@.@' b'4' b'\xe2\x80\x93' b'5' b'a.m.' b')' b'long' b','
b'than' b'which' b'may' b'one' b'reopened' b'of' b'the' b'works' b'.' b'It' b'may' b'exist' b'with' b'agriculture' b'such' b'as' b'local' b'areas' b'as' b'such'
b'as' b'a' b'type' b'that' b'may' b'occur' b'back' b'into' b'the' b'night' b',' b'but' b'further' b'does' b'not' b'include' b'military' b'forests' b'.' b'<eos>'
b'Mycena' b'Evans' b'Wi\xc5\x9bniowiecki' b'suggests' b'that' b'"' b'when' b'early' b'practice' b'carried' b'a' b'subspecies' b'strips' b',' b'adding' b'below' b'too' b'distant' b'or' b'to'
b'be' b'the' b'subject' b'of' b'it' b'during' b'Baby' b'terms' b'.' b'"' b'<eos>' b'The' b'Australian' b'starling' b'"' b'The' b'One' b'best' b'name' b'"'
b'have' b'already' b'sold' b'.' b'There' b'is' b'few' b'people' b'in' b'common' b'areas' b'that' b'may' b'be' b'obtained' b'it' b',' b'but' b'sort' b'of'
b'more' b'late' b'recorded' b',' b'hard' b'Ozawa' b',' b'nucleolar' b'bound' b',' b'and' b'Xemnas' b';' b'and' b'is' b'foraging' b'in' b'2000' b'.' b'<eos>'
b'Northern' b'Ireland' b'is' b'very' b'invisible' b'for' b'their' b'state' b',' b'and' b'in' b'<unk>' b',' b'they' b'once' b'have' b'been' b'became' b'inconclusive' b'.'
b'The' b'Maasai' b'which' b'develop' b'near' b'the' b'breeding' b'Trade' b'Island' b'below' b'of' b'theologian' b"'s" b'husband' b'to' b'increase' b',' b'their' b'mouth' b'Colfer'
b'they' b'are' b'.' b'<eos>' b'Scotland' b'blocks' b'by' b'the' b'species' b'and' b'domains' b'preventing' b'tends' b'to' b'work' b'about' b'into' b'the' b'population' b'.'
b'<unk>' b'of' b'volunteers' b'are' b'referring' b'to' b'other' b'ribosomal' b'motifs' b'where' b'other' b'birds' b'have' b'fallen' b'.' b'For' b'this' b'statistics' b'are' b'short'
b'near' b'horns' b',' b'but' b'R\xc3\xa9union' b'has' b'praised' b'it' b'with' b'his' b'main' b'body' b',' b'whereas' b'admire' b'little' b'eye' b'sequences' b'of' b'Bay'
b'177' b',' b'both' b'of' b'which' b'are' b'more' b'small' b'scoring' b'.' b'<eos>' b'Other' b'groups' b'may' b'have' b'completely' b'allowed' b'lobbying' b'to' b'within'
b'the' b'excuse' b'of' b'cameras' b'.' b'Unlike' b'also' b'droplets' b'or' b'<unk>' b',' b'they' b'are' b'loose' b'simultaneously' b'.' b'They' b'have' b'parallel' b'to'
b'pandemic' b'and' b'often' b'beat' b'contact' b'over' b'about' b'60' b'thousand' b'months' b'old' b'.' b'This' b'bird' b'is' b'three' b'more' b'active' b'.' b'Other'
b'other' b'species' b'were' b'<unk>' b'restricted' b'to' b'the' b'standard' b',' b'leaving' b'<unk>' b',' b'particularly' b'slightly' b'Abdi' b'devil' b',' b'resulting' b'on' b'fifty'
b'<unk>' b',' b'known' b'as' b'an' b'bonnet' b'/' b'possibly' b'scale' b',' b'with' b'simplistic' b'and' b'<unk>' b'.' b'Mycena' b'Josip' b'Roth' b',' b'infected'
b'a' b'distinctive' b'item' b'for' b'moving' b'in' b'County' b'City' b'"' b'Mr' b'<unk>' b'"' b'and' b'"' b'patriarch' b'"' b'as' b'"' b'apprehend' b'"'
b',' b'frame' b'expanding' b'capability' b'between' b'behind' b'history' b',' b'red' b',' b'isolated' b',' b'urine' b'(' b'back' b')' b'which' b'are' b'known' b'with'
b'ASCAP' b'.' b'In' b'fact' b',' b'for' b'26' b'%' b'of' b'a' b'year' b'point' b'around' b'65' b'million' b'50' b'(' b'blue' b'long' b')'
b'.' b'These' b'legs' b'John' b'Europos' b'(' b'A' b'<unk>' b')' b'is' b'a' b'recent' b'.' b'In' b'DD' b'eggs' b',' b'it' b'was' b'described'
b'by' b'Hawks' b'as' b'<unk>' b',' b'an' b'pair' b'of' b'charm' b'.' b'It' b'was' b'also' b'hunted' b'that' b'they' b'lived' b'with' b'interior' b'pagan'

它不是GPT-2,但看起来模型已经开始学习语言结构了!

演示动态量化,只需要再定义几个辅助函数:

# 设置BPTT(Backpropagation Through Time)的长度为25,即每次处理25个时间步的数据
bptt = 25# 定义损失函数为交叉熵损失函数,用于计算模型输出与真实标签之间的误差
criterion = nn.CrossEntropyLoss()# 测试时使用的batch size为1
eval_batch_size = 1# 创建测试数据集
def batchify(data, bsz):"""将原始数据分割成多个批次,每个批次大小为bsz。参数:data (Tensor): 原始数据bsz (int): 每个批次的大小返回:Tensor: 分批后的数据"""# 计算可以完整分成多少个批次nbatch = data.size(0) // bsz# 去掉不能整除的部分数据data = data.narrow(0, 0, nbatch * bsz)# 将数据均匀分配到每个批次中,并调整形状为(bsz, -1),然后进行转置和连续化return data.view(bsz, -1).t().contiguous()# 对测试数据进行分批处理
test_data = batchify(corpus.test, eval_batch_size)# 定义获取单个批次数据的函数
def get_batch(source, i):"""从source中提取一个批次的数据和对应的标签。参数:source (Tensor): 输入数据源i (int): 当前批次的起始位置返回:data (Tensor): 输入数据target (Tensor): 对应的标签数据"""# 确定当前批次的时间步长度,不超过bptt且不超过剩余数据长度seq_len = min(bptt, len(source) - 1 - i)# 提取输入数据data = source[i:i+seq_len]# 提取对应的标签数据,并将其展平为一维张量target = source[i+1:i+1+seq_len].reshape(-1)return data, target# 重新封装隐藏状态,使其脱离历史梯度
def repackage_hidden(h):"""包装隐藏状态以断开其历史记录,防止梯度在反向传播时回传到前面的批次。参数:h (Tensor or tuple): 隐藏状态返回:Tensor or tuple: 脱离历史后的隐藏状态"""if isinstance(h, torch.Tensor):return h.detach()else:return tuple(repackage_hidden(v) for v in h)# 定义评估函数
def evaluate(model_, data_source):"""对模型在给定数据源上的表现进行评估。参数:model_ (nn.Module): 训练好的模型data_source (Tensor): 数据源返回:float: 平均损失值"""# 将模型设置为评估模式,禁用dropout等训练专用操作model_.eval()total_loss = 0.# 初始化隐藏状态hidden = model_.init_hidden(eval_batch_size)# 在不计算梯度的情况下进行评估with torch.no_grad():for i in range(0, data_source.size(0) - 1, bptt):# 获取当前批次的数据和标签data, targets = get_batch(data_source, i)# 模型前向传播,得到输出和新的隐藏状态output, hidden = model_(data, hidden)# 重新封装隐藏状态,避免占用过多内存hidden = repackage_hidden(hidden)# 展平输出,以便与标签进行损失计算output_flat = output.view(-1, ntokens)# 累加当前批次的损失total_loss += len(data) * criterion(output_flat, targets).item()# 计算平均损失return total_loss / (len(data_source) - 1)

这段代码的主要功能是:

  1. 数据预处理:通过 batchify 函数将原始数据划分为固定大小的批次,适用于模型输入。
  2. 批量读取get_batch 函数从数据源中提取指定位置的一个批次数据及其标签。
  3. 隐藏状态管理repackage_hidden 函数用于切断隐藏状态的历史记录,防止梯度回传到前面的批次。
  4. 模型评估evaluate 函数对模型在测试数据上的性能进行评估,计算并返回平均损失。

④测试动态量化

最后,可以调用torch.quantization.quantize_dynamic。具体地说,

nn.LSTM和 nn.Linear模块将被量化;

▲指定要将权重转换为int8值。

import torch.quantizationquantized_model = torch.quantization.quantize_dynamic(model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

输出:

LSTMModel((drop): Dropout(p=0.5, inplace=False)(encoder): Embedding(33278, 512)(rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)(decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

模型看起来是一样的,这如何受益?首先,可以看到显著减小了模型尺寸。

def print_size_of_model(model):torch.save(model.state_dict(), "temp.p")print('Size (MB):', os.path.getsize("temp.p")/1e6)os.remove('temp.p')print_size_of_model(model)
print_size_of_model(quantized_model)

输出:

Size (MB): 113.944455
Size (MB): 79.738939

其次,可以看到更快的推理时间,评估损失没有差异。

注意:将线程数设置为单个线程比较的线程数,因为量化模型运行单线程。

# 设置 PyTorch 使用单个线程,以便更准确地测量模型推理时间
torch.set_num_threads(1)def time_model_evaluation(model, test_data):"""评估模型在测试数据上的性能,并统计推理时间。参数:model (nn.Module): 要评估的模型test_data (Tensor): 测试数据返回:打印模型损失和所用时间"""s = time.time()  # 记录开始时间loss = evaluate(model, test_data)  # 调用评估函数计算损失elapsed = time.time() - s  # 计算耗时# 打印损失值和耗时(秒)print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))# 分别对原始模型和量化模型进行评估并计时
time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

这段代码的主要目的是:

  1. 限制线程数:通过 torch.set_num_threads(1) 确保模型推理只使用一个线程,避免多线程影响时间测量准确性。
  2. 定义评估计时函数 time_model_evaluation
    • 对输入模型在测试数据上做评估;
    • 输出交叉熵损失和推理耗时。
  3. 对比评估两个模型
    • 原始浮点模型 model
    • 量化后的模型 quantized_model

这通常用于比较量化前后模型的推理速度精度损失,是模型压缩与优化中的常见做法。

输出

loss: 5.167
elapsed time (seconds): 198.3
loss: 5.168
elapsed time (seconds): 111.4

在 MacBook Pro上本地运行此功能,无需量化,推理大约需要 200 秒, 量化只需要大约100秒。

■结论

动态量化可以是一种减小模型尺寸的简单方法,而只对准确性带来有限的影响。

至此,本文分享的内容就结束了。

http://www.lqws.cn/news/535195.html

相关文章:

  • 供应链管理:主要生产计划类型及其相关信息
  • Solidity学习 - 认识Solidity合约结构
  • GitLab 18.1 发布 Runner、无效的个人访问令牌查看等功能,可升级体验!
  • 一分钟了解Transformer
  • 深入了解 AWS EventBridge
  • 无人机螺旋桨机械能模块解析
  • 深入解析前端 Meta 标签:HTML 的隐形守护者与功能大师
  • cudaStreamCreateWithPriority和cudaDeviceGetStreamPriorityRange
  • 基于vue框架的二手图书零售系统q7jqy(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。
  • 279. 完全平方数
  • 2025 Java开发生态全景图:云原生、AI与性能优化的技术融合
  • 用 Spark 优化亿级用户画像计算:Delta Lake 增量更新策略详解
  • flutter结合ai工具(其他语言通用)
  • 【CMake基础入门教程】第六课:构建静态库 / 动态库 与安装规则(install)
  • Linux命令:内置命令与外部命令的本质区别
  • MongoDB
  • jupyter notebook Kernel Restarting内核崩溃的解决
  • Linux命令与脚本:高效系统管理的双刃剑
  • 用户中心配置(资源、角色、用户配置)
  • 机器学习在智能农业中的创新应用与未来趋势
  • 【javascript】this关键字
  • vue + vue-router写登陆验证的同步方法和异步方法,及页面组件的分离和后端代码
  • Unity Netcode自定义数据传输——结构体及其序列化
  • .NET测试工具Parasoft dotTEST内置安全标准,编码合规更高效
  • 基于STM32的智能书房系统的设计
  • SpringBoot定时任务 - Timer实现方式
  • 算法打卡 day4
  • 大数据赋能智慧城市:从数据洪流到科学规划的“智慧之匙”
  • Leetcode百题斩-DP
  • 全面学习 OpenAI API:从 Python 教程到 API Key 使用详解,快速上手调用和部署