跳转到主要内容

多后端Keras。

项目描述

Keras 3:面向人类的深度学习

Keras 3是一个支持JAX、TensorFlow和PyTorch的多后端深度学习框架。轻松构建和训练模型,用于计算机视觉、自然语言处理、音频处理、时间序列预测、推荐系统等。

  • 加速模型开发:得益于Keras的高级用户界面以及易于调试的运行时(如PyTorch或JAX急切执行),更快地交付深度学习解决方案。
  • 最先进的性能:通过选择最适合您的模型架构的后端(通常是JAX!),与其他框架相比,速度提升可达20%至350%。 在此处进行基准测试
  • 数据中心规模训练:自信地从您的笔记本电脑扩展到大型GPU或TPU集群。

加入近三百万开发者,从初创企业到全球企业,利用Keras 3的力量。

安装

使用pip安装

Keras 3在PyPI上作为keras提供。请注意,Keras 2仍然以tf-keras包的形式提供。

  1. 安装keras
pip install keras --upgrade
  1. 安装后端包。

要使用 keras,您还应该安装所选的后端:tensorflowjaxtorch。请注意,使用某些 Keras 3 功能(例如某些预处理层以及 tf.data 管道)需要 tensorflow

本地安装

最小安装

Keras 3 兼容 Linux 和 MacOS 系统。对于 Windows 用户,我们建议使用 WSL2 运行 Keras。要安装本地开发版本

  1. 安装依赖项
pip install -r requirements.txt
  1. 从根目录运行安装命令。
python pip_build.py --install
  1. 当创建更新 keras_export 公共 API 的 PR 时,运行 API 生成脚本
./shell/api_gen.sh

添加 GPU 支持

requirements.txt 文件将安装 TensorFlow、JAX 和 PyTorch 的仅 CPU 版本。对于 GPU 支持,我们还为 TensorFlow、JAX 和 PyTorch 提供了单独的 requirements-{backend}-cuda.txt 文件。这些文件通过 pip 安装所有 CUDA 依赖项,并假定已预先安装 NVIDIA 驱动程序。我们建议为每个后端创建一个干净的 Python 环境,以避免 CUDA 版本不匹配。以下是如何使用 conda 创建 Jax GPU 环境的示例

conda create -y -n keras-jax python=3.10
conda activate keras-jax
pip install -r requirements-jax-cuda.txt
python pip_build.py --install

配置您的后端

您可以导出环境变量 KERAS_BACKEND,或者您可以编辑本地配置文件 ~/.keras/keras.json 来配置您的后端。可用的后端选项有:"tensorflow""jax""torch"。示例

export KERAS_BACKEND="jax"

In Colab,您可以这样做

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras

注意:必须在导入 keras 之前配置后端,并且在导入包之后不能更改后端。

向后兼容性

Keras 3 的目的是作为 tf.keras 的替换品使用(当使用 TensorFlow 后端时)。只需使用现有的 tf.keras 代码,确保您的 model.save() 调用使用最新的 .keras 格式,即可完成。

如果您的 tf.keras 模型不包含自定义组件,您可以直接在 JAX 或 PyTorch 上运行它。

如果它包含自定义组件(例如自定义层或自定义 train_step()),通常只需几分钟就可以将其转换为后端无关的实现。

此外,Keras 模型可以以任何格式消耗数据集,无论您使用的是哪种后端:您可以使用现有的 tf.data.Dataset 管道或 PyTorch DataLoaders 来训练模型。

为什么使用 Keras 3?

  • 在任意框架上运行您的高级 Keras 工作流程 -- 随意享受每个框架的优势,例如 JAX 的可扩展性和性能或 TensorFlow 的生产生态系统选项。
  • 编写自定义组件(例如层、模型、度量)以供您在任何框架的低级工作流程中使用。
    • 您可以在从头开始编写的原生 TF、JAX 或 PyTorch 训练循环中训练 Keras 模型。
    • 您可以将 Keras 模型用作 PyTorch 本地 Module 的部分或 JAX 本地模型函数的一部分。
  • 通过避免框架锁定,使您的 ML 代码面向未来。
  • 作为 PyTorch 用户:终于可以访问 Keras 的强大功能和易用性了!
  • 作为 JAX 用户:可以访问一个功能全面、经过实战检验、文档齐全的建模和训练库。

更多内容请参阅 Keras 3 发布公告

项目详情


发布历史 发布通知 | RSS 源

下载文件

下载适合您平台的文件。如果您不确定该选择哪个,请了解更多关于 安装包 的信息。

源代码分发

keras_nightly-3.6.0.dev2024100403.tar.gz (885.8 kB 查看哈希值)

上传时间 源代码

构建分发

keras_nightly-3.6.0.dev2024100403-py3-none-any.whl (1.2 MB 查看哈希值)

上传时间 Python 3

支持

AWSAWS云计算和安全赞助商DatadogDatadog监控FastlyFastlyCDNGoogleGoogle下载分析MicrosoftMicrosoftPSF赞助商PingdomPingdom监控SentrySentry错误记录StatusPageStatusPage状态页面