跳转到主要内容

研究生产力的构建模块。

项目描述

PyPI

🔥 Elements

研究生产力的构建模块。

安装

pip install elements

功能

Elements旨在为研究代码中的常见问题提供经过深思熟虑的解决方案。它也是可黑客攻击的。如果您需要更改一些代码,我们鼓励您将相应的文件叉到您的项目目录中并进行编辑。

elements.Logger

一种可扩展后端的数组类型记录器。指标在后台线程中写入,以避免阻塞程序执行,这在桶访问缓慢的云服务中尤为重要。

提供的后端

  • TerminalOutput(pattern):将标量打印到终端。可以通过正则表达式过滤到更少的指标。
  • JSONLOutput(logdir, filename, pattern):将标量写入JSONL文件。例如,可以直接使用pandas读取。
  • TensorBoardOutput(logdir):标量、直方图、图像、GIF。当当前文件超过大小限制时,自动启动新的事件文件以支持云存储,其中附加到文件需要完整下载和重新上传。
  • WandBOutput(pattern, **init_kwargs):字符串、直方图、图像、视频。
  • MLFlowOutput(run_name, resume_id):将所有类型的指标记录到MLFlow。
step = elements.Counter()
logger = elements.Logger(step, [
    elements.logger.TerminalOutput(),
    elements.logger.JSONLOutput(logdir, 'metrics.jsonl'),
    elements.logger.TensorBoardOutput(logdir),
    elements.logger.WandBOutput(name='name', project='project'),
])

step.increment()
logger.scalar('foo', 42)
logger.scalar('foo', 43)
logger.scalar('foo', 44)
logger.vector('vector', np.zeros(100))
logger.image('image', np.zeros((800, 600, 3, np.uint8)))
logger.video('video', np.zeros((100, 64, 64, 3, np.uint8)))

logger.add({'foo': 42, 'vector': np.zeros(100)}, prefix='scope')

logger.write()

elements.Config

一个不可变的嵌套目录,用于保存配置。键可以通过属性语法访问。值限于JSON支持的原始类型。在配置中替换值时检查类型。

config = elements.Config(
    logdir='path/to/dir',
    foo=dict(bar=42),
)

print(config)                      # Pretty printing
print(config.foo.bar)              # Attribute syntax
print(config['foo']['bar'])        # Dictionary syntax
config.logdir = 'path/to/new/dir'  # Not allowed

# Access a copy of the flattened dictionary underlying the config.
config.flat == {'logdir': 'path/to/dir', 'foo.bar': 42}

# Configs are immutable, so updating them returns a new object.
new_config = config.update({'foo.bar': 43})

# Types are enforced when updating configs, but values of other types are
# allowed as long as they can be converted without loss of information.
new_config = config.update({'foo.bar': float(1e5)})  # Allowed
new_config = config.update({'foo.bar': float(1.5)})  # Not allowed

# Configs can be saved and loaded in JSON and YAML formats.
config.save('config.json')
config = elements.Config.load('config.json')

elements.Flags

一个类似于argparse的命令行标志解析器,但使用更快且更灵活。强制类型并支持嵌套字典和一次通过正则表达式覆盖多个标志。

必须提供有效键及其默认值的映射以推断类型。因为所有值都有默认值,所以没有用户必须在命令行上指定的必填参数。

# Create flags parser from default values.
flags = elements.Flags(logdir='path/to/dir', bar=42)

# Create flags parser from config.
flags = elements.Flags(elements.Config({
    'logdir': 'path/to/dir',
    'foo.bar': 42,
    'list': [1, 2, 3],
}))

# Load a config from YAML and overwrite it from it from the command line.
config = elements.Config.load('defaults.yaml')
config = elements.Flags(config).parse()

# Overwrite some of the keys.
config = flags.parse(['--logdir', 'path/to/new/dir', '--foo.bar', '43'])

# Supports syntax with space or equals.
config = flags.parse(['--logdir=path/to/new/dir'])

# Overwrite lists.
config = flags.parse(['--list', '10', '20', '30'])
config = flags.parse(['--list', '10,20,30'])
config = flags.parse(['--list=10,20,30'])

# Set all nested keys that end in 'bar'.
config = flags.parse(['--.*\.bar$', '43'])

# Parse only known flags.
config, remaining = flags.parse_known(['--logdir', 'dir', '--other', '123'])
remaining == ['--other', '123']

# Print a help page and terminate the program.
flags.parse(['--help'])

# Print a help page without terminating the program.
flags = elements.Flags(logdir='path/to/dir', bar=42, help_exits=False)
parsed, remaining = flags.parse_known(['--help', '--other=value'])
remaining == ['--help', '--other=value']
second_parser.parse(remaining)  # Now we exit.

elements.Path

类似于pathlib的文件系统抽象,可扩展到新的文件系统。包含对本地文件系统和GCS存储桶的支持。

path = elements.Path('gs://bucket/path/to/file.txt')

# String operations
path.parent                           # gs://bucket/path/to
path.name                             # file.txt
path.stem                             # file
path.suffix                           # .txt

# File operations
path.read(mode='r')                   # Content of the file as string
path.read(mode='rb')                  # Content of the file as bytes
path.write(content, mode='w')         # Write string to the file
path.write(content, mode='wb')        # Write bytes to the file
with path.open(mode='r') as f:        # Create a file pointer
  pass

# File system checks
path.parent.glob('*')                 # Get all sibling paths
path.exists()                         # True
path.isdir()                          # False
path.isfile()                         # True

# File system changes
(path.parent / 'foo').mkdir()         # Creates directory including parents
path.remove()                         # Deletes a file or empty directory
path.parent.rmtree()                  # Deletes directory and its content
path.copy(path.parent / 'copy.txt')   # Makes a copy
path.move(path.parent / 'moved.txt')  # Moves the file

elements.Checkpoint

包含可以保存到磁盘和从磁盘加载的对象集合。

每个附加到检查点的对象都需要实现 save() -> dataload(data) 方法,其中 data 是可序列化的。

检查点是在后台线程中编写的,以避免阻塞程序执行。新的检查点首先写入临时路径,然后一旦完全写入就重命名为实际路径,这样即使程序在写入过程中被终止,路径也始终指向有效的名称。

step = elements.Counter()

cp = elements.Checkpoint(directory)
# Attach objects to the checkpoint.
cp.step = step
cp.model = model
# After attaching the objects we load the checkpoint from disk if it exists
# and otherwise save an initial checkpoint.
cp.load_or_save()

# Later on, we can change the objects and then save the checkpoint again.
should_save = elements.when.Every(10)
for _ in range(100):
  step.increment()
  if should_save(step):
    cp.save()

# We can also load checkpoints or parts of a checkpoint from a different directory.
cp.load(pretraining_directory, keys=['model'])
print(cp.model)

elements.Timer

收集程序不同部分的运行时统计信息。测量代码段内的代码,可以将方法包装到段中。返回执行次数、执行时间、程序总时间的比例等。可以将生成的统计信息添加到日志记录器。

timer = Timer()

timer.section('foo'):
  time.sleep(10)

timer.wrap('name', obj, ['method1', 'method2'])
obj.method1()
obj.method1()
obj.method1()
obj.method2()

stats = timer.stats(reset=True, log=True)
stats == {
    'foo_count': 1,
    'foo_total': 10.0,
    'foo_avg': 10.0,
    'foo_min': 10.0,
    'foo_max': 10.0,
    'foo_frac': 0.92,
    'name.method1_count': 3,
    'name.method1_frac': 0.07,
    # ...
    'name.method2_frac': 0.01,
    # ...
}

elements.when

运行代码的辅助工具,例如每隔固定步骤、秒或时间的某个比例。计数是健壮的,所以当你跳过一个步骤时,它将在下一次运行代码以赶上。

should = elements.when.Every(100)
for step in range(1000):
  if should(step):
    print(step)  # 0, 100, 200, ...

should = elements.when.Ratio(0.3333)
for step in range(100):
  if should(step):
    print(step)  # 0, 4, 7, 10, 13, 16, ...

should = elements.when.Once()
for step in range(100):
  if should(step):
    print(step)  # 0

should = elements.when.Until(5)
for step in range(10):
  if should(step):
    print(step)  # 0, 1, 2, 3, 4

should = elements.when.Clock(1)
for step in range(100):
  if should(step):
    print(step)  # 0, 10, 20, 30, ...
  time.sleep(0.1)

elements.plotting

具有合理默认值的存储、加载和绘图数据工具。

数据以 run 格式存储在 gzip 压缩的 JSON 文件中。每个文件包含一个或多个运行列表。运行是一个字典,包含键 taskmethodseedxsys。任务、方法和种子是字符串字段,而 xs 和 ys 是包含要绘制的数据的相等长度的数字列表。

查看存储库中的 plotting.py 以查看所有可用函数的列表,而不仅仅是此代码段中使用的函数。

from elements import plotting

runs = plotting.load('filename.json.gz')
plotting.dump(runs, 'filename.json.gz')

bins = np.linspace(0, 1e6, 100)
tensor, tasks, methods, seeds = plotting.tensor(runs, bins)
tensor.shape == (len(tasks), len(methods), len(seeds), len(bins))

fig, axes = plotting.plots(len(tasks))

for i, task in enumerate(tasks):
  ax = axes[i]
  for j, method in enumerate(methods):
    # Aggregate over seeds.
    mean = np.nanmean(tensor[i, j, :, :], 2)
    std = np.nanstd(tensor[i, j, :, :], 2)
    plotting.curve(ax, bins[1:], mean, std, label=method, order=j)

plotting.legend(fig, adjust=True)

# Saves the figure in both PNG and PDF formats and attempts to crop margins off
# the PDF.
plotting.save(fig, 'path/to/name')

问题

请在 Github 上提交 问题

项目详情


发布历史 发布通知 | RSS 源

下载文件

下载适合您的平台的文件。如果您不确定要选择哪个,请了解更多关于 安装包 的信息。

源分布

elements-3.15.13.tar.gz (36.6 kB 查看哈希值)

上传时间

由以下支持

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误日志 StatusPage StatusPage 状态页面