transformer微调预训练模型

先写到前面:

1. 加载数据集

  • 这里我主要是想要搞清楚魔术方法的问题:

基本概念:

  • 定义:在 Python 中,以双下划线开头和结尾的方法(例如 __xxx__)被称为特殊方法(也称为魔术方法或 dunder 方法,“dunder” 是 “double underscore” 的缩写)。
  • Python 解释器自动调用(特殊性质):特殊方法会在特定的操作或事件发生时被 Python 解释器自动调用,而不需要显式地调用这些方法。例如,当使用 len() 函数时,Python 解释器会自动调用对象的 __len__ 方法;当使用索引访问对象(如 obj[index])时,会自动调用对象的 __getitem__ 方法。

关于除了__init__方法以外的其他方法的调用:

__getitem__ 方法在索引访问时被调用的原理

Python 为了实现统一且灵活的对象操作方式,定义了一系列的特殊方法(魔术方法)。这些特殊方法允许自定义类的对象表现得像 Python 内置的数据类型(如列表、字典等)一样。

当你使用 obj[index] 这种索引访问语法时,Python 解释器会自动查找 obj 对象是否实现了 __getitem__ 方法。如果实现了,就会调用该方法并将 index 作为参数传递给它。这是 Python 语言的一种约定和机制,通过这种方式,你可以让自定义类支持索引访问操作。

以下是一个简单的示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class MyList:
def __init__(self, data):
self.data = data

def __getitem__(self, index):
print(f"__getitem__ 方法被调用,索引为 {index}")
return self.data[index]

# 创建 MyList 类的实例
my_list = MyList([10, 20, 30])

# 使用索引访问操作,会触发 __getitem__ 方法
result = my_list[1]
print(result)

在这个例子中,当执行 my_list[1] 时,Python 解释器检测到 my_list 对象实现了 __getitem__ 方法,就会自动调用该方法,并将 1 作为参数传递给它。

  • 对于我们要学习的Dataset类,

关于Dataset类的基本使用流程:

1. 导入必要的库

1
2
from torch.utils.data import Dataset
import json
  • torch.utils.data.Dataset:这是 PyTorch 中用于表示数据集的抽象基类,自定义的数据集类通常需要继承这个类,以便能够使用 PyTorch 提供的数据集处理和数据加载工具。
  • json:Python 标准库中的 json 模块,用于处理 JSON 格式的数据。

2. 定义自定义数据集类 AFQMC

1
2
3
class AFQMC(Dataset):
def __init__(self, data_file):
self.data = self.load_data(data_file)
  • AFQMC 类继承自 Dataset 类。
  • __init__ 方法是类的构造函数,在创建 AFQMC 类的实例时会被调用。它接收一个参数 data_file,表示数据文件的路径。在方法内部,调用了 self.load_data(data_file) 方法来加载数据,并将加载的数据存储在 self.data 中。
1
2
3
4
5
6
7
def load_data(self, data_file):
Data = {}
with open(data_file, 'rt') as f:
for idx, line in enumerate(f):
sample = json.loads(line.strip())
Data[idx] = sample
return Data
  • load_data 方法是自定义的方法,用于从指定的 JSON 文件中加载数据。
  • 它首先创建一个空字典 Data,然后逐行读取文件。对于每一行,使用 json.loads 方法将其解析为 JSON 对象,并将其存储在字典 Data 中,键为行索引 idx
  • 最后返回加载好的数据字典。
1
2
def __len__(self):
return len(self.data)
  • __len__ 方法是必须实现的方法,它返回数据集的长度,即数据样本的数量。在这个例子中,返回的是 self.data 字典的长度。
1
2
def __getitem__(self, idx):
return self.data[idx]
  • __getitem__ 方法也是必须实现的方法,它用于根据给定的索引 idx 从数据集中获取一个样本。在这个例子中,直接从 self.data 字典中获取对应索引的样本并返回。

3. 创建数据集实例并打印样本

1
2
3
4
train_data = AFQMC('data/afqmc_public/train.json')
valid_data = AFQMC('data/afqmc_public/dev.json')

print(train_data[0])
  • 创建了两个 AFQMC 类的实例 train_datavalid_data,分别加载训练集和验证集的数据。
  • 打印训练集的第一个样本。

Dataset 在数据微调中的作用

在深度学习模型微调(fine-tuning)过程中,Dataset 类起着至关重要的作用:

  • 数据组织Dataset 类可以将原始数据(如文本、图像等)组织成模型可以处理的格式。通过自定义 __getitem__ 方法,可以灵活地对数据进行预处理,如分词、编码等。
  • 数据加载Dataset 类可以与 DataLoader 类配合使用,实现数据的批量加载和并行处理。DataLoader 可以根据需要对数据进行打乱、分批等操作,提高数据加载的效率。
  • 代码复用:自定义的 Dataset 类可以在不同的实验和项目中复用,方便对数据进行统一管理和处理。

Dataset 继承的方法体现

AFQMC 类继承自 Dataset 类,虽然 Dataset 是一个抽象基类,没有具体的实现代码,但它定义了一些必须实现的方法,这些方法在 AFQMC 类中得到了具体实现:

  • __len__ 方法:用于返回数据集的长度。在使用 DataLoader 加载数据时,DataLoader 会调用这个方法来确定数据集的大小。
  • __getitem__ 方法:用于根据索引获取数据集中的一个样本。DataLoader 在加载数据时,会通过这个方法逐批获取数据。

通过实现这两个方法,AFQMC 类就符合了 Dataset 类的接口规范,可以与 PyTorch 提供的其他数据处理和模型训练工具无缝集成。

DatasetDataloader类之间的配合:

DataLoader 是 PyTorch 中用于批量加载数据的工具,它与自定义的 Dataset 类(如 LawDataset)紧密配合,以高效地为模型训练或推理提供数据。具体的配合过程如下:

1. 初始化阶段

当你创建 DataLoader 实例时,需要将 Dataset 实例作为参数传入,例如:

1
2
3
from torch.utils.data import DataLoader
# 假设已经定义了 LawDataset 类并创建了 dataset 实例
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

这里的 batch_size 表示每个批次包含的样本数量,shuffle 表示是否在每个 epoch 开始时打乱数据集。

2. 数据加载阶段

当你使用 for 循环遍历 DataLoader 时,DataLoader 会根据 batch_size 的设置,多次调用 Dataset 类的 __getitem__ 方法来获取样本。具体步骤如下:

  • DataLoader 首先会确定当前批次需要获取的样本索引。
  • 对于每个索引,DataLoader 会调用 Dataset 实例的 __getitem__ 方法,传入该索引,从而获取对应的样本。
  • 当获取到 batch_size 个样本后,DataLoader 会将这些样本组合成一个批次,并将其返回给循环。

例如,当 batch_size = 2 时,DataLoader 会先调用 __getitem__(0) 得到第一个样本,再调用 __getitem__(1) 得到第二个样本,然后将这两个样本组合成一个批次。

补充几点:

__getitem__ 方法每次只处理一个文本的原因

__getitem__ 方法设计为每次只处理一个索引对应的文本,这是为了保证代码的简洁性和灵活性。通过将数据处理逻辑封装在 __getitem__ 方法中,DataLoader 可以根据需要多次调用该方法,从而实现批量数据的加载。如果 __getitem__ 方法要处理多个文本,会增加方法的复杂度,并且不利于数据的随机访问和批量处理。

__getitem__ 方法中 self.texts[idx] 的调用逻辑

__getitem__ 方法中,self.texts[idx] 只是获取 self.texts 列表中索引为 idx 的文本元素,它本身并不会触发整个 __getitem__ 方法。__getitem__ 方法是在 DataLoader 调用时被触发的,DataLoader 会传入一个索引 idx,然后执行 __getitem__ 方法中的代码,self.texts[idx] 只是 __getitem__ 方法中的一个操作步骤。

以下是一个简单的示例代码,展示了 DataLoaderDataset 类的配合过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
def __init__(self, data):
self.data = data

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
print(f"Getting item at index {idx}")
return self.data[idx]

# 示例数据
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

for batch in dataloader:
print(f"Batch: {batch}")

在这个示例中,DataLoader 每次会调用 MyDataset__getitem__ 方法两次(因为 batch_size = 2),获取两个样本并组合成一个批次,然后将批次返回给循环进行处理。

综上所述,DataLoader 通过多次调用 Dataset 类的 __getitem__ 方法来实现批量数据的加载,而 __getitem__ 方法每次只处理一个索引对应的样本,保证了代码的简洁性和灵活性。