mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			84 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Markdown
		
	
	
	
			
		
		
	
	
			84 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Markdown
		
	
	
	
| ## data.DataLoader
 | ||
| ```python
 | ||
| class DataSet
 | ||
| ```
 | ||
| DataLoader数据加载器,支持数据批处理和随机采样
 | ||
| 
 | ||
| ---
 | ||
| ### `DataLoader(dataset, batch_size, shuffle, num_workers)`
 | ||
| 创建一个DataLoader
 | ||
| 
 | ||
| 参数:
 | ||
| - `dataset:DataSet` 数据集实例
 | ||
| - `batch_size:int` 批处理大小
 | ||
| - `shuffle:bool` 打乱数据集标记,默认为True
 | ||
| - `num_workers:int` 线程数,默认为0
 | ||
| 
 | ||
| 返回:数据加载器
 | ||
| 
 | ||
| 返回类型:`DataLoader`
 | ||
| 
 | ||
| ---
 | ||
| ### `iter_number`
 | ||
| 
 | ||
| 返回总迭代次数,当剩余的数据在一个批次大小中没有满仍然会被加载
 | ||
| 
 | ||
| 属性类型:只读
 | ||
| 
 | ||
| 类型:`int`
 | ||
| 
 | ||
| ---
 | ||
| ### `size`
 | ||
| 
 | ||
| 获取数据集大小
 | ||
| 
 | ||
| 属性类型:只读
 | ||
| 
 | ||
| 类型:`int`
 | ||
| 
 | ||
| ---
 | ||
| ### `reset()`
 | ||
| 
 | ||
| 重置数据加载器,数据加载器每次用完后都需要重置
 | ||
| 
 | ||
| 返回:`None`
 | ||
| 
 | ||
| 返回类型:`None`
 | ||
| 
 | ||
| ---
 | ||
| ### `next()`
 | ||
| 
 | ||
| 在数据集中获取批量数据
 | ||
| 
 | ||
| 返回:`([Var], [Var])` 两组数据,第一组为输入数据,第二组为结果数据
 | ||
| 
 | ||
| 返回类型:`tuple`
 | ||
| 
 | ||
| 示例:
 | ||
| 
 | ||
| ```python
 | ||
| train_dataset = MnistDataset(True)
 | ||
| test_dataset = MnistDataset(False)
 | ||
| train_dataloader = data.DataLoader(train_dataset, batch_size = 64, shuffle = True)
 | ||
| test_dataloader = data.DataLoader(test_dataset, batch_size = 100, shuffle = False)
 | ||
| ...
 | ||
| # use in training
 | ||
| def train_func(net, train_dataloader, opt):
 | ||
|     """train function"""
 | ||
|     net.train(True)
 | ||
|     # need to reset when the data loader exhausted
 | ||
|     train_dataloader.reset()
 | ||
|     t0 = time.time()
 | ||
|     for i in range(train_dataloader.iter_number):
 | ||
|         example = train_dataloader.next()
 | ||
|         input_data = example[0]
 | ||
|         output_target = example[1]
 | ||
|         data = input_data[0]  # which input, model may have more than one inputs
 | ||
|         label = output_target[0]  # also, model may have more than one outputs
 | ||
|         predict = net.forward(data)
 | ||
|         target = expr.one_hot(expr.cast(label, expr.int), 10, 1, 0)
 | ||
|         loss = nn.loss.cross_entropy(predict, target)
 | ||
|         opt.step(loss)
 | ||
|         if i % 100 == 0:
 | ||
|             print("train loss: ", loss.read())
 | ||
| ``` |