研究生产力的构建模块。
项目描述
🔥 Elements
研究生产力的构建模块。
安装
pip install elements
功能
Elements旨在为研究代码中的常见问题提供经过深思熟虑的解决方案。它也是可黑客攻击的。如果您需要更改一些代码,我们鼓励您将相应的文件叉到您的项目目录中并进行编辑。
elements.Logger
一种可扩展后端的数组类型记录器。指标在后台线程中写入,以避免阻塞程序执行,这在桶访问缓慢的云服务中尤为重要。
提供的后端
TerminalOutput(pattern)
:将标量打印到终端。可以通过正则表达式过滤到更少的指标。JSONLOutput(logdir, filename, pattern)
:将标量写入JSONL文件。例如,可以直接使用pandas读取。TensorBoardOutput(logdir)
:标量、直方图、图像、GIF。当当前文件超过大小限制时,自动启动新的事件文件以支持云存储,其中附加到文件需要完整下载和重新上传。WandBOutput(pattern, **init_kwargs)
:字符串、直方图、图像、视频。MLFlowOutput(run_name, resume_id)
:将所有类型的指标记录到MLFlow。
step = elements.Counter()
logger = elements.Logger(step, [
elements.logger.TerminalOutput(),
elements.logger.JSONLOutput(logdir, 'metrics.jsonl'),
elements.logger.TensorBoardOutput(logdir),
elements.logger.WandBOutput(name='name', project='project'),
])
step.increment()
logger.scalar('foo', 42)
logger.scalar('foo', 43)
logger.scalar('foo', 44)
logger.vector('vector', np.zeros(100))
logger.image('image', np.zeros((800, 600, 3, np.uint8)))
logger.video('video', np.zeros((100, 64, 64, 3, np.uint8)))
logger.add({'foo': 42, 'vector': np.zeros(100)}, prefix='scope')
logger.write()
elements.Config
一个不可变的嵌套目录,用于保存配置。键可以通过属性语法访问。值限于JSON支持的原始类型。在配置中替换值时检查类型。
config = elements.Config(
logdir='path/to/dir',
foo=dict(bar=42),
)
print(config) # Pretty printing
print(config.foo.bar) # Attribute syntax
print(config['foo']['bar']) # Dictionary syntax
config.logdir = 'path/to/new/dir' # Not allowed
# Access a copy of the flattened dictionary underlying the config.
config.flat == {'logdir': 'path/to/dir', 'foo.bar': 42}
# Configs are immutable, so updating them returns a new object.
new_config = config.update({'foo.bar': 43})
# Types are enforced when updating configs, but values of other types are
# allowed as long as they can be converted without loss of information.
new_config = config.update({'foo.bar': float(1e5)}) # Allowed
new_config = config.update({'foo.bar': float(1.5)}) # Not allowed
# Configs can be saved and loaded in JSON and YAML formats.
config.save('config.json')
config = elements.Config.load('config.json')
elements.Flags
一个类似于argparse
的命令行标志解析器,但使用更快且更灵活。强制类型并支持嵌套字典和一次通过正则表达式覆盖多个标志。
必须提供有效键及其默认值的映射以推断类型。因为所有值都有默认值,所以没有用户必须在命令行上指定的必填参数。
# Create flags parser from default values.
flags = elements.Flags(logdir='path/to/dir', bar=42)
# Create flags parser from config.
flags = elements.Flags(elements.Config({
'logdir': 'path/to/dir',
'foo.bar': 42,
'list': [1, 2, 3],
}))
# Load a config from YAML and overwrite it from it from the command line.
config = elements.Config.load('defaults.yaml')
config = elements.Flags(config).parse()
# Overwrite some of the keys.
config = flags.parse(['--logdir', 'path/to/new/dir', '--foo.bar', '43'])
# Supports syntax with space or equals.
config = flags.parse(['--logdir=path/to/new/dir'])
# Overwrite lists.
config = flags.parse(['--list', '10', '20', '30'])
config = flags.parse(['--list', '10,20,30'])
config = flags.parse(['--list=10,20,30'])
# Set all nested keys that end in 'bar'.
config = flags.parse(['--.*\.bar$', '43'])
# Parse only known flags.
config, remaining = flags.parse_known(['--logdir', 'dir', '--other', '123'])
remaining == ['--other', '123']
# Print a help page and terminate the program.
flags.parse(['--help'])
# Print a help page without terminating the program.
flags = elements.Flags(logdir='path/to/dir', bar=42, help_exits=False)
parsed, remaining = flags.parse_known(['--help', '--other=value'])
remaining == ['--help', '--other=value']
second_parser.parse(remaining) # Now we exit.
elements.Path
类似于pathlib
的文件系统抽象,可扩展到新的文件系统。包含对本地文件系统和GCS存储桶的支持。
path = elements.Path('gs://bucket/path/to/file.txt')
# String operations
path.parent # gs://bucket/path/to
path.name # file.txt
path.stem # file
path.suffix # .txt
# File operations
path.read(mode='r') # Content of the file as string
path.read(mode='rb') # Content of the file as bytes
path.write(content, mode='w') # Write string to the file
path.write(content, mode='wb') # Write bytes to the file
with path.open(mode='r') as f: # Create a file pointer
pass
# File system checks
path.parent.glob('*') # Get all sibling paths
path.exists() # True
path.isdir() # False
path.isfile() # True
# File system changes
(path.parent / 'foo').mkdir() # Creates directory including parents
path.remove() # Deletes a file or empty directory
path.parent.rmtree() # Deletes directory and its content
path.copy(path.parent / 'copy.txt') # Makes a copy
path.move(path.parent / 'moved.txt') # Moves the file
elements.Checkpoint
包含可以保存到磁盘和从磁盘加载的对象集合。
每个附加到检查点的对象都需要实现 save() -> data
和 load(data)
方法,其中 data
是可序列化的。
检查点是在后台线程中编写的,以避免阻塞程序执行。新的检查点首先写入临时路径,然后一旦完全写入就重命名为实际路径,这样即使程序在写入过程中被终止,路径也始终指向有效的名称。
step = elements.Counter()
cp = elements.Checkpoint(directory)
# Attach objects to the checkpoint.
cp.step = step
cp.model = model
# After attaching the objects we load the checkpoint from disk if it exists
# and otherwise save an initial checkpoint.
cp.load_or_save()
# Later on, we can change the objects and then save the checkpoint again.
should_save = elements.when.Every(10)
for _ in range(100):
step.increment()
if should_save(step):
cp.save()
# We can also load checkpoints or parts of a checkpoint from a different directory.
cp.load(pretraining_directory, keys=['model'])
print(cp.model)
elements.Timer
收集程序不同部分的运行时统计信息。测量代码段内的代码,可以将方法包装到段中。返回执行次数、执行时间、程序总时间的比例等。可以将生成的统计信息添加到日志记录器。
timer = Timer()
timer.section('foo'):
time.sleep(10)
timer.wrap('name', obj, ['method1', 'method2'])
obj.method1()
obj.method1()
obj.method1()
obj.method2()
stats = timer.stats(reset=True, log=True)
stats == {
'foo_count': 1,
'foo_total': 10.0,
'foo_avg': 10.0,
'foo_min': 10.0,
'foo_max': 10.0,
'foo_frac': 0.92,
'name.method1_count': 3,
'name.method1_frac': 0.07,
# ...
'name.method2_frac': 0.01,
# ...
}
elements.when
运行代码的辅助工具,例如每隔固定步骤、秒或时间的某个比例。计数是健壮的,所以当你跳过一个步骤时,它将在下一次运行代码以赶上。
should = elements.when.Every(100)
for step in range(1000):
if should(step):
print(step) # 0, 100, 200, ...
should = elements.when.Ratio(0.3333)
for step in range(100):
if should(step):
print(step) # 0, 4, 7, 10, 13, 16, ...
should = elements.when.Once()
for step in range(100):
if should(step):
print(step) # 0
should = elements.when.Until(5)
for step in range(10):
if should(step):
print(step) # 0, 1, 2, 3, 4
should = elements.when.Clock(1)
for step in range(100):
if should(step):
print(step) # 0, 10, 20, 30, ...
time.sleep(0.1)
elements.plotting
具有合理默认值的存储、加载和绘图数据工具。
数据以 run
格式存储在 gzip 压缩的 JSON 文件中。每个文件包含一个或多个运行列表。运行是一个字典,包含键 task
、method
、seed
、xs
、ys
。任务、方法和种子是字符串字段,而 xs 和 ys 是包含要绘制的数据的相等长度的数字列表。
查看存储库中的 plotting.py
以查看所有可用函数的列表,而不仅仅是此代码段中使用的函数。
from elements import plotting
runs = plotting.load('filename.json.gz')
plotting.dump(runs, 'filename.json.gz')
bins = np.linspace(0, 1e6, 100)
tensor, tasks, methods, seeds = plotting.tensor(runs, bins)
tensor.shape == (len(tasks), len(methods), len(seeds), len(bins))
fig, axes = plotting.plots(len(tasks))
for i, task in enumerate(tasks):
ax = axes[i]
for j, method in enumerate(methods):
# Aggregate over seeds.
mean = np.nanmean(tensor[i, j, :, :], 2)
std = np.nanstd(tensor[i, j, :, :], 2)
plotting.curve(ax, bins[1:], mean, std, label=method, order=j)
plotting.legend(fig, adjust=True)
# Saves the figure in both PNG and PDF formats and attempts to crop margins off
# the PDF.
plotting.save(fig, 'path/to/name')
问题
请在 Github 上提交 问题。
项目详情
elements-3.15.13.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | ad530f0c6649032bdf7af9f86e66fa91792c81271c5264b4d6df321cb9015549 |
|
MD5 | a716210e0f66f6c67d6699f50f467cb0 |
|
BLAKE2b-256 | 3cf063b0ed0badd8f6d7c8a6ba5e5e2ef3425a5c3ac4bba0a7c5eec896decbf1 |