跳转到主要内容

Gradientzoo Python绑定

项目描述

Documentation Status

这是一个用于Gradientzoo API的Python库 - 版本和分享您的训练好的神经网络模型。使用Gradientzoo轻松加载预训练的神经网络。以下是使用Tensorflow加载模型的简单方法(下面是完整的示例)

import tensorflow as tf
from gradientzoo.tensorflow import TensorflowGradientzoo

# (build MNIST graph here)

with tf.Session() as sess:
    # Load latest weights from Gradientzoo
    TensorflowGradientzoo('ericflo/mnist').load(sess)

    # Graph is now ready to use!

保存模型同样简单

import tensorflow as tf
from gradientzoo import TensorflowGradientzoo

# (build MNIST graph here)

with tf.Session() as sess:
    for epoch in xrange(6):
        # Train the model...

        # Save the updated weights out to Gradientzoo
        TensorflowGradientzoo('ericflo/mnist').save(sess)

功能

支持使用您选择的框架以Python保存Keras模型、变量在Tensorflow中,以及在Lasagne中的网络,以及使用Python的常规旧文件。

安装

除非您想修改此包,否则不需要此源代码。如果您只想使用Gradientzoo Python绑定,您应该运行

pip install –upgrade gradientzoo

或者

easy_install –upgrade gradientzoo

有关安装pip的说明,请参阅http://www.pip-installer.org/en/latest/index.html。如果您在一个有easy_install但没有pip的系统上,您可以使用easy_install。如果您不使用virtualenv,您可能需要在那些命令前加上sudo。您可以在http://www.virtualenv.org/上了解更多关于virtualenv的信息。

要从源安装,请运行

python setup.py install

文档

请参阅http://python-gradientzoo.readthedocs.org/获取最新文档,或访问项目页面查看特定于项目的说明,例如:https://www.gradientzoo.com/ericflo/mnist

设置Gradientzoo账户

https://www.gradientzoo.com/register注册Gradientzoo

贡献

支持

如果您遇到问题,请通过support@gradientzoo.com告知我们

完整Tensorflow示例

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data, mnist
from gradientzoo.tensorflow import TensorflowGradientzoo

learning_rate = 0.01
batch_size = 100

# Build MNIST graph
images_placeholder = tf.placeholder(tf.float32,
                                    shape=(batch_size, mnist.IMAGE_PIXELS))
labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
logits = mnist.inference(images_placeholder, 128, 32)
loss = mnist.loss(logits, labels_placeholder)
train_op = mnist.training(loss, learning_rate)
eval_correct = mnist.evaluation(logits, labels_placeholder)

# Start a Tensorflow session
with tf.Session() as sess:
    # Load latest weights from Gradientzoo
    TensorflowGradientzoo('ericflo/mnist').load(sess)

    # Read in some data
    data_sets = input_data.read_data_sets('data', False)

    # Test the trained network on the dataset
    true_count = 0
    for step in xrange(data_sets.test.num_examples // batch_size):
        images_feed, labels_feed = data_sets.test.next_batch(batch_size, False)

        true_count += sess.run(eval_correct, feed_dict={
            images_placeholder: images_feed,
            labels_placeholder: labels_feed,
        })

    precision = true_count / float(data_sets.test.num_examples)
    print('Num Examples: %d  Num Correct: %d  Precision: %0.04f' %
          (data_sets.test.num_examples, true_count, precision))

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源代码分布

gradientzoo-0.8.8.tar.gz (7.5 kB 查看哈希值)

上传时间: 源代码

支持