Classifying Images with DeiT
For detailed information on classifying images with DeiT, follow the README.md in the DeiT repository. For a quick test, first install the required packages:
pip install torch torchvision timm pandas requests
To run in Google Colab, install the dependencies with the following command:
!pip install timm pandas requests
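The script below was originally written for PyTorch 1.8.0, while a fresh install will usually pull in a newer release, so it can help to check which versions pip actually installed. Below is a minimal sketch using the standard library's `importlib.metadata` (Python 3.8+). The helper name `installed_versions` is hypothetical, and the package list simply mirrors the pip command above:

```python
from importlib import metadata


def installed_versions(packages):
    """Return {distribution name: installed version string, or None if missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in the current environment
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Distribution names taken from the pip install command above
    required = ["torch", "torchvision", "timm", "pandas", "requests"]
    for name, version in installed_versions(required).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

A missing package shows up as `NOT INSTALLED` instead of raising, so the check can run before any of the heavy imports succeed.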
Then run the following script:
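from PIL import Image
import torch
import timm  # the DeiT hub entry point requires timm to be installed
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# The tutorial was originally written for PyTorch 1.8.0; the output below was produced with 2.7.0.

# Load DeiT-Base (16x16 patches, 224x224 input) with pretrained ImageNet weights from PyTorch Hub
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()  # switch to inference mode

# Resize the shorter side to 256 (bicubic), center-crop to 224x224,
# convert to a [0, 1] tensor, then normalize with the ImageNet mean/std
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

url = "https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png"
img = Image.open(requests.get(url, stream=True).raw).convert("RGB")  # ensure 3 channels
img = transform(img).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)

with torch.no_grad():  # gradients are not needed for inference
    out = model(img)  # class logits, shape (1, 1000)
clsidx = torch.argmax(out)
print(clsidx.item())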
2.7.0+cu126
269
The output should be 269, which, according to the ImageNet mapping from class indices to labels, corresponds to "timber wolf, grey wolf, gray wolf, Canis lupus".
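Printing only the argmax index hides how confident the model is and requires a manual lookup of the label. The following is a minimal, dependency-free sketch that turns the logits into softmax probabilities and lists the top-k classes with their names. It assumes the logits have been converted to a plain Python list (for example `out[0].tolist()` from the script above) and that `labels` is a list of the 1,000 ImageNet class names in index order. The helper name `top_k` and the labels list are illustrative assumptions, not part of the original script.

```python
import math


def top_k(logits, labels, k=5):
    """Return the k most likely (index, label, probability) triples, best first.

    logits: plain sequence of floats, e.g. out[0].tolist() (assumed input form)
    labels: class names in index order (assumed: one entry per ImageNet class)
    """
    # Numerically stable softmax: subtract the largest logit before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort class indices by probability, highest first, and keep the first k
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(i, labels[i], probs[i]) for i in order]
```

With the script's output and a labels list, `for idx, name, p in top_k(out[0].tolist(), labels): print(idx, name, round(p, 3))` would list the five most likely classes, and the first index matches what `torch.argmax(out)` returns. Inside PyTorch itself, `torch.topk(out.softmax(dim=-1), 5)` performs the same computation on the tensor directly; the plain-Python version is only meant to make the arithmetic explicit.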
Now that we have verified that the DeiT model can classify images, let's look at how to modify the model so that it can run in iOS and Android apps.