跳转到主要内容

Merlin Dataloader

项目描述

Merlin Dataloader

PyPI - Python Version PyPI version shields.io GitHub License Documentation

The merlin-dataloader lets you quickly train recommender models for TensorFlow, PyTorch and JAX. It eliminates the biggest bottleneck in training recommender models, by providing GPU optimized dataloaders that read data directly into the GPU, and then do a 0-copy transfer to TensorFlow and PyTorch using dlpack.

Merlin Dataloader的优点包括

  • 比原生框架数据加载器快10倍以上
  • 处理大于内存的数据集
  • 每个epoch的洗牌
  • 分布式训练

安装

Merlin-dataloader需要Python版本3.7+。另外,GPU支持需要CUDA 11.0+。

使用Conda安装

conda install -c nvidia -c rapidsai -c numba -c conda-forge merlin-dataloader python=3.7 cudatoolkit=11.2

从PyPi安装

pip install merlin-dataloader

还有NGC上的带有merlin-dataloader及其依赖项的Docker容器

基本用法

# Get a merlin dataset from a set of parquet files
import merlin.io
dataset = merlin.io.Dataset(PARQUET_FILE_PATHS, engine="parquet")

# Create a Tensorflow dataloader from the dataset, loading 65K items
# per batch
from merlin.dataloader.tensorflow import Loader
loader = Loader(dataset, batch_size=65536)

# Get a single batch of data. Inputs will be a dictionary of columnname
# to TensorFlow tensors
inputs, target = next(loader)

# Train a Keras model with the dataloader
model = tf.keras.Model( ... )
model.fit(loader, epochs=5)

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源代码分发

merlin-dataloader-23.8.0.tar.gz (46.9 kB 查看哈希值)

上传时间: 源代码

由以下支持