跳转到主要内容

Replicate的Python客户端

项目描述

Replicate Python客户端

这是Replicate的Python客户端。它允许您从Python代码或Jupyter笔记本中运行模型,并在Replicate上执行各种其他操作。

👋Google Colab上查看此教程的交互式版本。

Open In Colab

需求

  • Python 3.8+

安装

pip install replicate

认证

在运行任何使用API的Python脚本之前,您需要将您的Replicate API令牌设置到您的环境中。

replicate.com/account 获取您的令牌并将其设置为环境变量

export REPLICATE_API_TOKEN=<your token>

我们建议不要直接将令牌添加到源代码中,因为您不希望将凭据放入源代码控制。如果有人使用您的API密钥,他们的使用将被计费到您的账户。

运行模型

创建一个新的Python文件,并添加以下代码,用您自己的模型标识符和输入替换

>>> import replicate
>>> replicate.run(
        "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
        input={"prompt": "a 19th century portrait of a wombat gentleman"}
    )

['https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png']

[!TIP] 您也可以通过在方法名前添加 async_ 来异步使用Replicate客户端。

以下是一个同时运行多个预测并等待它们全部完成的示例

import asyncio
import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]

async with asyncio.TaskGroup() as tg:
    tasks = [
        tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
        for prompt in prompts
    ]

results = await asyncio.gather(*tasks)
print(results)

要运行需要文件输入的模型,您可以传递指向互联网上公开可访问文件的URL或指向本地设备上文件的句柄。

>>> output = replicate.run(
        "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
        input={ "image": open("path/to/mystery.jpg") }
    )

"an astronaut riding a horse"

replicate.run 在预测失败时会引发 ModelError。您可以通过访问异常的 prediction 属性来获取有关失败更详细的信息。

import replicate
from replicate.exceptions import ModelError

try:
  output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "An astronaut riding a rainbow unicorn" })
except ModelError as e
  if "(some known issue)" in e.prediction.logs:
    pass

  print("Failed prediction: " + e.prediction.id)

运行模型并流式传输其输出

Replicate的API支持服务器发送事件流(SSE)。使用 stream 方法在模型生成令牌时消耗令牌。

import replicate

for event in replicate.stream(
    "meta/meta-llama-3-70b-instruct",
    input={
        "prompt": "Please write a haiku about llamas.",
    },
):
    print(str(event), end="")

[!TIP] 一些模型,如 meta/meta-llama-3-70b-instruct,不需要版本字符串。您始终可以参考模型页面的API文档以获取详细信息。

您还可以流式传输您创建的预测的输出。当您希望预测ID与其输出分离时,这很有帮助。

prediction = replicate.predictions.create(
    model="meta/meta-llama-3-70b-instruct"
    input={"prompt": "Please write a haiku about llamas."},
    stream=True,
)

for event in prediction.stream():
    print(str(event), end="")

有关更多信息,请参阅Replicate文档中的“流式传输输出”

在后台运行模型

您可以启动模型并在后台运行它

>>> model = replicate.models.get("kvfrans/clipdraw")
>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
>>> prediction = replicate.predictions.create(
    version=version,
    input={"prompt":"Watercolor painting of an underwater submarine"})

>>> prediction
Prediction(...)

>>> prediction.status
'starting'

>>> dict(prediction)
{"id": "...", "status": "starting", ...}

>>> prediction.reload()
>>> prediction.status
'processing'

>>> print(prediction.logs)
iteration: 0, render:loss: -0.6171875
iteration: 10, render:loss: -0.92236328125
iteration: 20, render:loss: -1.197265625
iteration: 30, render:loss: -1.3994140625

>>> prediction.wait()

>>> prediction.status
'succeeded'

>>> prediction.output
'https://.../output.png'

在后台运行模型并获取webhook

您可以在模型完成后运行模型并获取webhook,而不是等待其完成

model = replicate.models.get("ai-forever/kandinsky-2.2")
version = model.versions.get("ea1addaab376f4dc227f5368bbd8eff901820fd1cc14ed8cad63b29249e9d463")
prediction = replicate.predictions.create(
    version=version,
    input={"prompt":"Watercolor painting of an underwater submarine"},
    webhook="https://example.com/your-webhook",
    webhook_events_filter=["completed"]
)

有关接收webhook的详细信息,请参阅replicate.com/docs/webhooks

将模型组合成管道

您可以将一个模型的输出作为另一个模型的输入运行

laionide = replicate.models.get("afiaka87/laionide-v4").versions.get("b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05")
swinir = replicate.models.get("jingyunliang/swinir").versions.get("660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a")
image = laionide.predict(prompt="avocado armchair")
upscaled_image = swinir.predict(image=image)

从正在运行的模型获取输出

在模型运行时运行模型并获取其输出

iterator = replicate.run(
    "pixray/text2image:5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf",
    input={"prompts": "san francisco sunset"}
)

for image in iterator:
    display(image)

取消预测

您可以取消正在运行的预测

>>> model = replicate.models.get("kvfrans/clipdraw")
>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
>>> prediction = replicate.predictions.create(
        version=version,
        input={"prompt":"Watercolor painting of an underwater submarine"}
    )

>>> prediction.status
'starting'

>>> prediction.cancel()

>>> prediction.reload()
>>> prediction.status
'canceled'

列出预测

您可以列出您已运行的所有预测

replicate.predictions.list()
# [<Prediction: 8b0ba5ab4d85>, <Prediction: 494900564e8c>]

预测列表是分页的。您可以通过将 next 属性作为参数传递给 list 方法来获取下一页的预测

page1 = replicate.predictions.list()

if page1.next:
    page2 = replicate.predictions.list(page1.next)

加载输出文件

输出文件作为HTTPS URL返回。您可以将输出文件作为缓冲区加载

import replicate
from PIL import Image
from urllib.request import urlretrieve

out = replicate.run(
    "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
    input={"prompt": "wavy colorful abstract patterns, oceans"}
    )

urlretrieve(out[0], "/tmp/out.png")
background = Image.open("/tmp/out.png")

列出模型

您可以列出您创建的模型

replicate.models.list()

模型列表是分页的。您可以通过将 next 属性作为参数传递给 list 方法来获取下一页的模型,或者您可以使用 paginate 方法自动获取页面。

# Automatic pagination using `replicate.paginate` (recommended)
models = []
for page in replicate.paginate(replicate.models.list):
    models.extend(page.results)
    if len(models) > 100:
        break

# Manual pagination using `next` cursors
page = replicate.models.list()
while page:
    models.extend(page.results)
    if len(models) > 100:
          break
    page = replicate.models.list(page.next) if page.next else None

您还可以在Replicate上找到特色模型的集合

>>> collections = [collection for page in replicate.paginate(replicate.collections.list) for collection in page]
>>> collections[0].slug
"vision-models"
>>> collections[0].description
"Multimodal large language models with vision capabilities like object detection and optical character recognition (OCR)"

>>> replicate.collections.get("text-to-image").models
[<Model: stability-ai/sdxl>, ...]

创建模型

您可以为具有给定名称、可见性和硬件SKU的用户或组织创建模型

import replicate

model = replicate.models.create(
    owner="your-username",
    name="my-model",
    visibility="public",
    hardware="gpu-a40-large"
)

以下是列出在Replicate上运行模型的所有可用硬件的方法

>>> [hw.sku for hw in replicate.hardware.list()]
['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large']

微调模型

使用训练API微调模型以使其在特定任务上表现更好。要查看当前支持微调的 语言模型,请查看Replicate的可训练语言模型集合

如果您想微调 图像模型,请参阅Replicate的图像模型微调指南

以下是微调Replicate上模型的方法

training = replicate.trainings.create(
    model="stability-ai/sdxl",
    version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={
      "input_images": "https://my-domain/training-images.zip",
      "token_string": "TOK",
      "caption_prefix": "a photo of TOK",
      "max_train_steps": 1000,
      "use_face_detection_instead": False
    },
    # You need to create a model on Replicate that will be the destination for the trained version.
    destination="your-username/model-name"
)

自定义客户端行为

《replicate》包导出默认共享客户端。此客户端使用由 REPLICATE_API_TOKEN 环境变量设置的 API 令牌初始化。

您可以创建自己的客户端实例,传入不同的 API 令牌值,添加自定义请求头,或控制底层 HTTPX 客户端 的行为。

import os
from replicate.client import Client

replicate = Client(
  api_token=os.environ["SOME_OTHER_REPLICATE_API_TOKEN"]
  headers={
    "User-Agent": "my-app/1.0"
  }
)

[!警告] 不要将认证凭据(如 API 令牌)硬编码到您的代码中。相反,在运行程序时,将它们作为环境变量传递。

开发

请参阅 CONTRIBUTING.md

项目详情


发布历史 发布通知 | RSS 源

下载文件

下载适合您平台文件。如果您不确定选择哪个,请了解有关 安装软件包 的更多信息。

源分发

replicate-0.34.2.tar.gz (55.8 kB 查看哈希值)

上传时间

构建分发

replicate-0.34.2-py3-none-any.whl (45.9 kB 查看哈希值)

上传于 Python 3