MNN/docs/pymnn/DataLoader.md

84 lines
2.0 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

## data.DataLoader
```python
class DataSet
```
DataLoader数据加载器支持数据批处理和随机采样
---
### `DataLoader(dataset, batch_size, shuffle, num_workers)`
创建一个DataLoader
参数:
- `dataset:DataSet` 数据集实例
- `batch_size:int` 批处理大小
- `shuffle:bool` 打乱数据集标记默认为True
- `num_workers:int` 线程数默认为0
返回:数据加载器
返回类型:`DataLoader`
---
### `iter_number`
返回总迭代次数,当剩余的数据在一个批次大小中没有满仍然会被加载
属性类型:只读
类型:`int`
---
### `size`
获取数据集大小
属性类型:只读
类型:`int`
---
### `reset()`
重置数据加载器,数据加载器每次用完后都需要重置
返回:`None`
返回类型:`None`
---
### `next()`
在数据集中获取批量数据
返回:`([Var], [Var])` 两组数据,第一组为输入数据,第二组为结果数据
返回类型:`tuple`
示例:
```python
train_dataset = MnistDataset(True)
test_dataset = MnistDataset(False)
train_dataloader = data.DataLoader(train_dataset, batch_size = 64, shuffle = True)
test_dataloader = data.DataLoader(test_dataset, batch_size = 100, shuffle = False)
...
# use in training
def train_func(net, train_dataloader, opt):
"""train function"""
net.train(True)
# need to reset when the data loader exhausted
train_dataloader.reset()
t0 = time.time()
for i in range(train_dataloader.iter_number):
example = train_dataloader.next()
input_data = example[0]
output_target = example[1]
data = input_data[0] # which input, model may have more than one inputs
label = output_target[0] # also, model may have more than one outputs
predict = net.forward(data)
target = expr.one_hot(expr.cast(label, expr.int), 10, 1, 0)
loss = nn.loss.cross_entropy(predict, target)
opt.step(loss)
if i % 100 == 0:
print("train loss: ", loss.read())
```