transformer微调预训练模型
transformer微调预训练模型
先写到前面:
- 本文只是对网上transformer学习的文档的注解,并非完整教程
- 相关文档链接:transformer快速入门教程
1. 加载数据集
- 这里我主要是想要搞清楚魔术方法的问题:
基本概念:
- 定义:在 Python 中,以双下划线开头和结尾的方法(例如
__xxx__)被称为特殊方法(也称为魔术方法或 dunder 方法,“dunder” 是 “double underscore” 的缩写)。 - Python 解释器自动调用(特殊性质):特殊方法会在特定的操作或事件发生时被 Python 解释器自动调用,而不需要显式地调用这些方法。例如,当使用
len()函数时,Python 解释器会自动调用对象的__len__方法;当使用索引访问对象(如obj[index])时,会自动调用对象的__getitem__方法。
关于除了__init__方法以外的其他方法的调用:
__getitem__ 方法在索引访问时被调用的原理
Python 为了实现统一且灵活的对象操作方式,定义了一系列的特殊方法(魔术方法)。这些特殊方法允许自定义类的对象表现得像 Python 内置的数据类型(如列表、字典等)一样。
当你使用 obj[index] 这种索引访问语法时,Python 解释器会自动查找 obj 对象是否实现了 __getitem__ 方法。如果实现了,就会调用该方法并将 index 作为参数传递给它。这是 Python 语言的一种约定和机制,通过这种方式,你可以让自定义类支持索引访问操作。
以下是一个简单的示例:
1 | class MyList: |
在这个例子中,当执行 my_list[1] 时,Python 解释器检测到 my_list 对象实现了 __getitem__ 方法,就会自动调用该方法,并将 1 作为参数传递给它。
- 对于我们要学习的
Dataset类,
关于Dataset类的基本使用流程:
1. 导入必要的库
1 | from torch.utils.data import Dataset |
torch.utils.data.Dataset:这是 PyTorch 中用于表示数据集的抽象基类,自定义的数据集类通常需要继承这个类,以便能够使用 PyTorch 提供的数据集处理和数据加载工具。json:Python 标准库中的json模块,用于处理 JSON 格式的数据。
2. 定义自定义数据集类 AFQMC
1 | class AFQMC(Dataset): |
AFQMC类继承自Dataset类。__init__方法是类的构造函数,在创建AFQMC类的实例时会被调用。它接收一个参数data_file,表示数据文件的路径。在方法内部,调用了self.load_data(data_file)方法来加载数据,并将加载的数据存储在self.data中。
1 | def load_data(self, data_file): |
load_data方法是自定义的方法,用于从指定的 JSON 文件中加载数据。- 它首先创建一个空字典
Data,然后逐行读取文件。对于每一行,使用json.loads方法将其解析为 JSON 对象,并将其存储在字典Data中,键为行索引idx。 - 最后返回加载好的数据字典。
1 | def __len__(self): |
__len__方法是必须实现的方法,它返回数据集的长度,即数据样本的数量。在这个例子中,返回的是self.data字典的长度。
1 | def __getitem__(self, idx): |
__getitem__方法也是必须实现的方法,它用于根据给定的索引idx从数据集中获取一个样本。在这个例子中,直接从self.data字典中获取对应索引的样本并返回。
3. 创建数据集实例并打印样本
1 | train_data = AFQMC('data/afqmc_public/train.json') |
- 创建了两个
AFQMC类的实例train_data和valid_data,分别加载训练集和验证集的数据。 - 打印训练集的第一个样本。
Dataset 在数据微调中的作用
在深度学习模型微调(fine-tuning)过程中,Dataset 类起着至关重要的作用:
- 数据组织:
Dataset类可以将原始数据(如文本、图像等)组织成模型可以处理的格式。通过自定义__getitem__方法,可以灵活地对数据进行预处理,如分词、编码等。 - 数据加载:
Dataset类可以与DataLoader类配合使用,实现数据的批量加载和并行处理。DataLoader可以根据需要对数据进行打乱、分批等操作,提高数据加载的效率。 - 代码复用:自定义的
Dataset类可以在不同的实验和项目中复用,方便对数据进行统一管理和处理。
Dataset 继承的方法体现
AFQMC 类继承自 Dataset 类,虽然 Dataset 是一个抽象基类,没有具体的实现代码,但它定义了一些必须实现的方法,这些方法在 AFQMC 类中得到了具体实现:
__len__方法:用于返回数据集的长度。在使用DataLoader加载数据时,DataLoader会调用这个方法来确定数据集的大小。__getitem__方法:用于根据索引获取数据集中的一个样本。DataLoader在加载数据时,会通过这个方法逐批获取数据。
通过实现这两个方法,AFQMC 类就符合了 Dataset 类的接口规范,可以与 PyTorch 提供的其他数据处理和模型训练工具无缝集成。
Dataset和Dataloader类之间的配合:
DataLoader 是 PyTorch 中用于批量加载数据的工具,它与自定义的 Dataset 类(如 LawDataset)紧密配合,以高效地为模型训练或推理提供数据。具体的配合过程如下:
1. 初始化阶段
当你创建 DataLoader 实例时,需要将 Dataset 实例作为参数传入,例如:
1 | from torch.utils.data import DataLoader |
这里的 batch_size 表示每个批次包含的样本数量,shuffle 表示是否在每个 epoch 开始时打乱数据集。
2. 数据加载阶段
当你使用 for 循环遍历 DataLoader 时,DataLoader 会根据 batch_size 的设置,多次调用 Dataset 类的 __getitem__ 方法来获取样本。具体步骤如下:
DataLoader首先会确定当前批次需要获取的样本索引。- 对于每个索引,
DataLoader会调用Dataset实例的__getitem__方法,传入该索引,从而获取对应的样本。 - 当获取到
batch_size个样本后,DataLoader会将这些样本组合成一个批次,并将其返回给循环。
例如,当 batch_size = 2 时,DataLoader 会先调用 __getitem__(0) 得到第一个样本,再调用 __getitem__(1) 得到第二个样本,然后将这两个样本组合成一个批次。
补充几点:
__getitem__ 方法每次只处理一个文本的原因
__getitem__ 方法设计为每次只处理一个索引对应的文本,这是为了保证代码的简洁性和灵活性。通过将数据处理逻辑封装在 __getitem__ 方法中,DataLoader 可以根据需要多次调用该方法,从而实现批量数据的加载。如果 __getitem__ 方法要处理多个文本,会增加方法的复杂度,并且不利于数据的随机访问和批量处理。
__getitem__ 方法中 self.texts[idx] 的调用逻辑
在 __getitem__ 方法中,self.texts[idx] 只是获取 self.texts 列表中索引为 idx 的文本元素,它本身并不会触发整个 __getitem__ 方法。__getitem__ 方法是在 DataLoader 调用时被触发的,DataLoader 会传入一个索引 idx,然后执行 __getitem__ 方法中的代码,self.texts[idx] 只是 __getitem__ 方法中的一个操作步骤。
以下是一个简单的示例代码,展示了 DataLoader 与 Dataset 类的配合过程:
1 | from torch.utils.data import Dataset, DataLoader |
在这个示例中,DataLoader 每次会调用 MyDataset 的 __getitem__ 方法两次(因为 batch_size = 2),获取两个样本并组合成一个批次,然后将批次返回给循环进行处理。
综上所述,DataLoader 通过多次调用 Dataset 类的 __getitem__ 方法来实现批量数据的加载,而 __getitem__ 方法每次只处理一个索引对应的样本,保证了代码的简洁性和灵活性。



