Replicate的Python客户端
项目描述
Replicate Python客户端
这是Replicate的Python客户端。它允许您从Python代码或Jupyter笔记本中运行模型,并在Replicate上执行各种其他操作。
👋 在Google Colab上查看此教程的交互式版本。
需求
- Python 3.8+
安装
pip install replicate
认证
在运行任何使用API的Python脚本之前,您需要将您的Replicate API令牌设置到您的环境中。
从 replicate.com/account 获取您的令牌并将其设置为环境变量
export REPLICATE_API_TOKEN=<your token>
我们建议不要直接将令牌添加到源代码中,因为您不希望将凭据放入源代码控制。如果有人使用您的API密钥,他们的使用将被计费到您的账户。
运行模型
创建一个新的Python文件,并添加以下代码,用您自己的模型标识符和输入替换
>>> import replicate
>>> replicate.run(
"stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
input={"prompt": "a 19th century portrait of a wombat gentleman"}
)
['https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png']
[!TIP] 您也可以通过在方法名前添加
async_
来异步使用Replicate客户端。以下是一个同时运行多个预测并等待它们全部完成的示例
import asyncio import replicate # https://replicate.com/stability-ai/sdxl model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" prompts = [ f"A chariot pulled by a team of {count} rainbow unicorns" for count in ["two", "four", "six", "eight"] ] async with asyncio.TaskGroup() as tg: tasks = [ tg.create_task(replicate.async_run(model_version, input={"prompt": prompt})) for prompt in prompts ] results = await asyncio.gather(*tasks) print(results)
要运行需要文件输入的模型,您可以传递指向互联网上公开可访问文件的URL或指向本地设备上文件的句柄。
>>> output = replicate.run(
"andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
input={ "image": open("path/to/mystery.jpg") }
)
"an astronaut riding a horse"
replicate.run
在预测失败时会引发 ModelError
。您可以通过访问异常的 prediction
属性来获取有关失败更详细的信息。
import replicate
from replicate.exceptions import ModelError
try:
output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "An astronaut riding a rainbow unicorn" })
except ModelError as e
if "(some known issue)" in e.prediction.logs:
pass
print("Failed prediction: " + e.prediction.id)
运行模型并流式传输其输出
Replicate的API支持服务器发送事件流(SSE)。使用 stream
方法在模型生成令牌时消耗令牌。
import replicate
for event in replicate.stream(
"meta/meta-llama-3-70b-instruct",
input={
"prompt": "Please write a haiku about llamas.",
},
):
print(str(event), end="")
[!TIP] 一些模型,如 meta/meta-llama-3-70b-instruct,不需要版本字符串。您始终可以参考模型页面的API文档以获取详细信息。
您还可以流式传输您创建的预测的输出。当您希望预测ID与其输出分离时,这很有帮助。
prediction = replicate.predictions.create(
model="meta/meta-llama-3-70b-instruct"
input={"prompt": "Please write a haiku about llamas."},
stream=True,
)
for event in prediction.stream():
print(str(event), end="")
有关更多信息,请参阅Replicate文档中的“流式传输输出”。
在后台运行模型
您可以启动模型并在后台运行它
>>> model = replicate.models.get("kvfrans/clipdraw")
>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
>>> prediction = replicate.predictions.create(
version=version,
input={"prompt":"Watercolor painting of an underwater submarine"})
>>> prediction
Prediction(...)
>>> prediction.status
'starting'
>>> dict(prediction)
{"id": "...", "status": "starting", ...}
>>> prediction.reload()
>>> prediction.status
'processing'
>>> print(prediction.logs)
iteration: 0, render:loss: -0.6171875
iteration: 10, render:loss: -0.92236328125
iteration: 20, render:loss: -1.197265625
iteration: 30, render:loss: -1.3994140625
>>> prediction.wait()
>>> prediction.status
'succeeded'
>>> prediction.output
'https://.../output.png'
在后台运行模型并获取webhook
您可以在模型完成后运行模型并获取webhook,而不是等待其完成
model = replicate.models.get("ai-forever/kandinsky-2.2")
version = model.versions.get("ea1addaab376f4dc227f5368bbd8eff901820fd1cc14ed8cad63b29249e9d463")
prediction = replicate.predictions.create(
version=version,
input={"prompt":"Watercolor painting of an underwater submarine"},
webhook="https://example.com/your-webhook",
webhook_events_filter=["completed"]
)
有关接收webhook的详细信息,请参阅replicate.com/docs/webhooks。
将模型组合成管道
您可以将一个模型的输出作为另一个模型的输入运行
laionide = replicate.models.get("afiaka87/laionide-v4").versions.get("b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05")
swinir = replicate.models.get("jingyunliang/swinir").versions.get("660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a")
image = laionide.predict(prompt="avocado armchair")
upscaled_image = swinir.predict(image=image)
从正在运行的模型获取输出
在模型运行时运行模型并获取其输出
iterator = replicate.run(
"pixray/text2image:5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf",
input={"prompts": "san francisco sunset"}
)
for image in iterator:
display(image)
取消预测
您可以取消正在运行的预测
>>> model = replicate.models.get("kvfrans/clipdraw")
>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
>>> prediction = replicate.predictions.create(
version=version,
input={"prompt":"Watercolor painting of an underwater submarine"}
)
>>> prediction.status
'starting'
>>> prediction.cancel()
>>> prediction.reload()
>>> prediction.status
'canceled'
列出预测
您可以列出您已运行的所有预测
replicate.predictions.list()
# [<Prediction: 8b0ba5ab4d85>, <Prediction: 494900564e8c>]
预测列表是分页的。您可以通过将 next
属性作为参数传递给 list
方法来获取下一页的预测
page1 = replicate.predictions.list()
if page1.next:
page2 = replicate.predictions.list(page1.next)
加载输出文件
输出文件作为HTTPS URL返回。您可以将输出文件作为缓冲区加载
import replicate
from PIL import Image
from urllib.request import urlretrieve
out = replicate.run(
"stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
input={"prompt": "wavy colorful abstract patterns, oceans"}
)
urlretrieve(out[0], "/tmp/out.png")
background = Image.open("/tmp/out.png")
列出模型
您可以列出您创建的模型
replicate.models.list()
模型列表是分页的。您可以通过将 next
属性作为参数传递给 list
方法来获取下一页的模型,或者您可以使用 paginate
方法自动获取页面。
# Automatic pagination using `replicate.paginate` (recommended)
models = []
for page in replicate.paginate(replicate.models.list):
models.extend(page.results)
if len(models) > 100:
break
# Manual pagination using `next` cursors
page = replicate.models.list()
while page:
models.extend(page.results)
if len(models) > 100:
break
page = replicate.models.list(page.next) if page.next else None
您还可以在Replicate上找到特色模型的集合
>>> collections = [collection for page in replicate.paginate(replicate.collections.list) for collection in page]
>>> collections[0].slug
"vision-models"
>>> collections[0].description
"Multimodal large language models with vision capabilities like object detection and optical character recognition (OCR)"
>>> replicate.collections.get("text-to-image").models
[<Model: stability-ai/sdxl>, ...]
创建模型
您可以为具有给定名称、可见性和硬件SKU的用户或组织创建模型
import replicate
model = replicate.models.create(
owner="your-username",
name="my-model",
visibility="public",
hardware="gpu-a40-large"
)
以下是列出在Replicate上运行模型的所有可用硬件的方法
>>> [hw.sku for hw in replicate.hardware.list()]
['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large']
微调模型
使用训练API微调模型以使其在特定任务上表现更好。要查看当前支持微调的 语言模型,请查看Replicate的可训练语言模型集合。
如果您想微调 图像模型,请参阅Replicate的图像模型微调指南。
以下是微调Replicate上模型的方法
training = replicate.trainings.create(
model="stability-ai/sdxl",
version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
input={
"input_images": "https://my-domain/training-images.zip",
"token_string": "TOK",
"caption_prefix": "a photo of TOK",
"max_train_steps": 1000,
"use_face_detection_instead": False
},
# You need to create a model on Replicate that will be the destination for the trained version.
destination="your-username/model-name"
)
自定义客户端行为
《replicate》包导出默认共享客户端。此客户端使用由 REPLICATE_API_TOKEN
环境变量设置的 API 令牌初始化。
您可以创建自己的客户端实例,传入不同的 API 令牌值,添加自定义请求头,或控制底层 HTTPX 客户端 的行为。
import os
from replicate.client import Client
replicate = Client(
api_token=os.environ["SOME_OTHER_REPLICATE_API_TOKEN"]
headers={
"User-Agent": "my-app/1.0"
}
)
[!警告] 不要将认证凭据(如 API 令牌)硬编码到您的代码中。相反,在运行程序时,将它们作为环境变量传递。
开发
请参阅 CONTRIBUTING.md
项目详情
下载文件
下载适合您平台文件。如果您不确定选择哪个,请了解有关 安装软件包 的更多信息。