Abnormal detection rate #9

Open
liyown opened this issue Oct 11, 2024 · 0 comments
liyown commented Oct 11, 2024

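Hello, and thank you very much for your work. I have run into a problem: following the logic of your code, I extracted the script below, but no matter which image I feed in, the outputs are always very large, in the range of 10 to 20. What could be the cause?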

import cv2
import torch
from networks.ssp import ssp
from torchvision import transforms
from utils.patch import patch_img
from PIL import Image

# Load the models
model = ssp().cuda()
model.load_state_dict(torch.load('models/midjourney.pth',weights_only=True))
model.eval()

model_sd = ssp().cuda()
model_sd.load_state_dict(torch.load('models/sd.pth',weights_only=True))
model_sd.eval()



patch_func = transforms.Lambda(
    lambda img: patch_img(img, 32, 256))

trans = transforms.Compose([
    patch_func,
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                            [0.229, 0.224, 0.225]),
])

path = '1212.jpg'

with open(path, 'rb') as f:
    img = Image.open(f).convert('RGB')
    
img = trans(img)

# Add a batch dimension
img = img.unsqueeze(0)

# Run inference
with torch.no_grad():
    out = model(img.cuda())
    out_sd = model_sd(img.cuda())
print(out)
print(out_sd)
out = out.ravel()
out_sd = out_sd.ravel()
print(out)
print(out_sd)
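
One possible reading of the values above, offered as an assumption and not something confirmed in this issue: if `ssp` returns a raw logit instead of a probability, outputs of 10 to 20 are not out of range as such, and a sigmoid would map them into [0, 1]. The sketch below uses plain Python only; `logit_to_prob` is a hypothetical helper, not part of the repository, and the input values are illustrative, not real model outputs.

```python
import math


def logit_to_prob(logit: float) -> float:
    """Map a raw logit to a probability with the logistic sigmoid.

    Hypothetical helper for illustration; it assumes the model output
    is a single unnormalized score per image.
    """
    # Branch on the sign so math.exp never receives a large positive argument.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    e = math.exp(logit)
    return e / (1.0 + e)


# Illustrative values in the range reported above (not real outputs).
for logit in (0.0, 10.0, 20.0):
    print(f"logit={logit:5.1f} -> prob={logit_to_prob(logit):.6f}")
```

On tensors, `torch.sigmoid(out)` performs the same mapping. Whether this is the right post-processing for these checkpoints, and which class a probability near 1 stands for, would need to be checked against the repository's own evaluation script.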