使用Keras Sequence懒加载混合序列,专注于多任务模型。
项目描述
使用Keras Sequence懒加载混合序列,专注于多任务模型。
我该如何安装这个包?
像往常一样,只需使用pip下载
pip install keras_mixed_sequence
使用示例
传统单任务模型的示例
首先,让我们创建一个简单的单任务模型
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
model = Sequential([
Dense(1, activation="relu")
])
model.compile(
optimizer="nadam",
loss="relu"
)
然后,我们继续加载或创建训练数据。在这里,将列出一些用于与该库一起使用的自定义Sequence对象。
X = either_a_numpy_array_or_sequence_for_input
y = either_a_numpy_array_or_sequence_for_output
现在我们使用MixedSequence对象组合训练数据。
from keras_mixed_sequence import MixedSequence
sequence = MixedSequence(
X, y,
batch_size=batch_size
)
最后,我们可以训练模型
from multiprocessing import cpu_count
model.fit_generator(
sequence,
steps_per_epoch=sequence.steps_per_epoch,
epochs=2,
verbose=0,
use_multiprocessing=True,
workers=cpu_count(),
shuffle=True
)
多任务模型的示例
首先,让我们创建一个简单的多任务模型
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input
inputs = Input(shape=(10,))
output1 = Dense(
units=10,
activation="relu",
name="output1"
)(inputs)
output2 = Dense(
units=10,
activation="relu",
name="output2"
)(inputs)
model = Model(
inputs=inputs,
outputs=[output1, output2],
name="my_model"
)
model.compile(
optimizer="nadam",
loss="MSE"
)
然后,我们继续加载或创建训练数据。在这里,将列出一些用于与该库一起使用的自定义Sequence对象。
X = either_a_numpy_array_or_sequence_for_input
y1 = either_a_numpy_array_or_sequence_for_output1
y2 = either_a_numpy_array_or_sequence_for_output2
现在我们使用MixedSequence对象组合训练数据。
from keras_mixed_sequence import MixedSequence
sequence = MixedSequence(
x=X,
y={
"output1": y1,
"output2": y2
},
batch_size=batch_size
)
最后,我们可以训练模型
from multiprocessing import cpu_count
model.fit_generator(
sequence,
steps_per_epoch=sequence.steps_per_epoch,
epochs=2,
verbose=0,
use_multiprocessing=True,
workers=cpu_count(),
shuffle=True
)
项目详情
关闭
keras_mixed_sequence-1.0.29.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 4d49f4325988dccd9d1f6d1ea53e1fc0be8a4479545387f79db7d9245ed66e91 |
|
MD5 | f284371a3bb9e306af1e5bed60dbd9d9 |
|
BLAKE2b-256 | 26970f66d9fa1579eaded44a0c572c8e2e9d2daa89f5d911d45f0307e8d2be73 |