跳转到主要内容

使用Keras Sequence懒加载混合序列,专注于多任务模型。

项目描述

Pypi project Pypi total project downloads

使用Keras Sequence懒加载混合序列,专注于多任务模型。

我该如何安装这个包?

像往常一样,只需使用pip下载

pip install keras_mixed_sequence

使用示例

传统单任务模型的示例

首先,让我们创建一个简单的单任务模型

from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

model = Sequential([
    Dense(1, activation="relu")
])
model.compile(
    optimizer="nadam",
    loss="relu"
)

然后,我们继续加载或创建训练数据。在这里,将列出一些用于与该库一起使用的自定义Sequence对象。

X = either_a_numpy_array_or_sequence_for_input
y = either_a_numpy_array_or_sequence_for_output

现在我们使用MixedSequence对象组合训练数据。

from keras_mixed_sequence import MixedSequence

sequence = MixedSequence(
    X, y,
    batch_size=batch_size
)

最后,我们可以训练模型

from multiprocessing import cpu_count

model.fit_generator(
    sequence,
    steps_per_epoch=sequence.steps_per_epoch,
    epochs=2,
    verbose=0,
    use_multiprocessing=True,
    workers=cpu_count(),
    shuffle=True
)

多任务模型的示例

首先,让我们创建一个简单的多任务模型

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input

inputs = Input(shape=(10,))

output1 = Dense(
    units=10,
    activation="relu",
    name="output1"
)(inputs)
output2 = Dense(
    units=10,
    activation="relu",
    name="output2"
)(inputs)

model = Model(
    inputs=inputs,
    outputs=[output1, output2],
    name="my_model"
)

model.compile(
    optimizer="nadam",
    loss="MSE"
)

然后,我们继续加载或创建训练数据。在这里,将列出一些用于与该库一起使用的自定义Sequence对象。

X = either_a_numpy_array_or_sequence_for_input
y1 = either_a_numpy_array_or_sequence_for_output1
y2 = either_a_numpy_array_or_sequence_for_output2

现在我们使用MixedSequence对象组合训练数据。

from keras_mixed_sequence import MixedSequence

sequence = MixedSequence(
    x=X,
    y={
        "output1": y1,
        "output2": y2
    },
    batch_size=batch_size
)

最后,我们可以训练模型

from multiprocessing import cpu_count

model.fit_generator(
    sequence,
    steps_per_epoch=sequence.steps_per_epoch,
    epochs=2,
    verbose=0,
    use_multiprocessing=True,
    workers=cpu_count(),
    shuffle=True
)

项目详情


下载文件

下载适用于您的平台文件。如果您不确定要选择哪个,请了解更多关于 安装包 的信息。

源分发

keras_mixed_sequence-1.0.29.tar.gz (7.1 kB 查看哈希值)

上传时间:

支持

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误日志 StatusPage StatusPage 状态页面