
How to use the trained model in demo_MLSD_flask? #28

Open
Code-Dataset opened this issue Oct 13, 2023 · 10 comments

Comments

@Code-Dataset

Hello, how can a trained model saved under workdir/models be used in demo_MLSD_flask.py?
When I run it, I get the following error:
Traceback (most recent call last):
  File "demo_MLSD_flask.py", line 296, in <module>
    init_worker(args)
  File "demo_MLSD_flask.py", line 255, in init_worker
    model = model_graph(args)
  File "demo_MLSD_flask.py", line 86, in __init__
    self.model = self.load(args.model_dir, args.model_type)
  File "demo_MLSD_flask.py", line 105, in load
    torch_model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
  File "C:\Users\ai\AppData\Roaming\Python\Python38\site-packages\torch\nn\modules\module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for MobileV2_MLSD_Tiny:
    Unexpected key(s) in state_dict: "block17.weight", "block17.bias".
    size mismatch for backbone.features.0.0.weight: copying a param with shape torch.Size([32, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 4, 3, 3]).
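
The two messages point at a model-definition mismatch rather than a corrupt checkpoint: the unexpected block17.* keys mean the checkpoint was trained with the deconv head enabled, while the size mismatch on backbone.features.0.0.weight (3 vs. 4 input channels) means demo_MLSD_flask.py builds the original 4-channel-input variant of MobileV2_MLSD_Tiny, whereas training used a 3-channel RGB stem. A minimal loading sketch, assuming the training-time class definition (the mlsd_tiny.py posted later in this thread) is importable; the checkpoint path is illustrative:

import torch
from mlsd_tiny import MobileV2_MLSD_Tiny  # training-time definition (posted below)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MobileV2_MLSD_Tiny(with_deconv=True)  # must match the training config
state = torch.load('workdir/models/best.pth', map_location=device)  # illustrative path
model.load_state_dict(state, strict=True)
model.to(device).eval()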

@Code-Dataset Code-Dataset changed the title from "How to use the trained model in demo_MLSD?" to "How to use the trained model in demo_MLSD_flask?" on Oct 13, 2023
@syvince

syvince commented Dec 15, 2023

@Code-Dataset Did you solve it?

@BALADA-CRAM


I have the same problem as you. Did you manage to solve it?

@syvince

syvince commented Apr 22, 2024 via email

I solved it. In the end I didn't run it with Flask; I used Gradio instead, but the principle is the same. You can take a look at the mlsd_tiny.py file, which is a version I rewrote based on the original source code.

@BALADA-CRAM

Could you send it to my email? It doesn't display here, so I can't see the file.
[email protected]

@ljl0311

ljl0311 commented Apr 27, 2024

@syvince
I'm also stuck on this. How did you get it working? I can't see the file here either; could you send me a copy as well? Thanks.
My email: [email protected]

@ganqii

ganqii commented Sep 10, 2024

I ran into this problem too. Could you send me a copy as well? Thank you!
Email: [email protected]

@syvince

syvince commented Sep 20, 2024

mlsd_tiny.py

import os
import sys
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn import functional as F


class BlockTypeA(nn.Module):
    def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale=True):
        super(BlockTypeA, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c2, out_c2, kernel_size=1),
            nn.BatchNorm2d(out_c2),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c1, out_c1, kernel_size=1),
            nn.BatchNorm2d(out_c1),
            nn.ReLU(inplace=True)
        )
        self.upscale = upscale

    def forward(self, a, b):
        b = self.conv1(b)
        a = self.conv2(a)
        if self.upscale:
            b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
        return torch.cat((a, b), dim=1)


class BlockTypeB(nn.Module):
    def __init__(self, in_c, out_c):
        super(BlockTypeB, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv1(x) + x
        x = self.conv2(x)
        return x


class BlockTypeC(nn.Module):
    def __init__(self, in_c, out_c):
        super(BlockTypeC, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_c),
            nn.ReLU()
        )
        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        self.channel_pad = out_planes - in_planes
        self.stride = stride
        # padding = (kernel_size - 1) // 2

        # TFLite uses slightly different padding than PyTorch
        # if stride == 2:
        #     padding = 0
        # else:
        padding = (kernel_size - 1) // 2

        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )
        self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)

    def forward(self, x):
        # # TFLite uses  different padding
        # if self.stride == 2:
        #     x = F.pad(x, (0, 1, 0, 1), "constant", 0)
        #     #print(x.shape)

        for module in self:
            if not isinstance(module, nn.MaxPool2d):
                x = module(x)
        return x


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, pretrained=True):
        """
        MobileNet V2 main class
        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
        """
        super(MobileNetV2, self).__init__()

        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        width_mult = 1.0
        round_nearest = 8

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            # [6, 96, 3, 1],
            # [6, 160, 3, 2],
            # [6, 320, 1, 1],
        ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        self.features = nn.Sequential(*features)

        self.fpn_selected = [1, 3, 6, 10]
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

        # if pretrained:
        #    self._load_pretrained_model()

    def _forward_impl(self, x):
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        fpn_features = []
        for i, f in enumerate(self.features):
            if i > self.fpn_selected[-1]:
                break
            x = f(x)
            if i in self.fpn_selected:
                fpn_features.append(x)

        c1, c2, c3, c4 = fpn_features
        return c1, c2, c3, c4

    def forward(self, x):
        return self._forward_impl(x)

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


class MobileV2_MLSD_Tiny(nn.Module):
    def __init__(self, with_deconv):
        super(MobileV2_MLSD_Tiny, self).__init__()

        self.backbone = MobileNetV2(pretrained=True)  # 3-channel RGB stem, unlike the demo's 4-channel variant

        self.block12 = BlockTypeA(in_c1=32, in_c2=64,
                                  out_c1=64, out_c2=64)
        self.block13 = BlockTypeB(128, 64)

        self.block14 = BlockTypeA(in_c1=24, in_c2=64,
                                  out_c1=32, out_c2=32)
        self.block15 = BlockTypeB(64, 64)

        self.block16 = BlockTypeC(64, 16)

        self.with_deconv = with_deconv

        # block17 only exists when with_deconv=True; loading such a checkpoint
        # into a model built without the deconv head raises the
        # "Unexpected key(s)" error from the first post.
        if self.with_deconv:
            self.block17 = BilinearConvTranspose2d(16, 2, 1)
            self.block17.reset_parameters()

    def forward(self, x):
        c1, c2, c3, c4 = self.backbone(x)

        x = self.block12(c3, c4)
        x = self.block13(x)
        x = self.block14(c2, x)
        x = self.block15(x)
        x = self.block16(x)
        # x = x[:, 7:, :, :]
        # print(x.shape)
        if self.with_deconv:
            x = self.block17(x)
        else:
            x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)

        return x


class BilinearConvTranspose2d(nn.ConvTranspose2d):
    """A conv transpose initialized to bilinear interpolation."""

    def __init__(self, channels, stride, groups=1):
        """Set up the layer.

        Parameters
        ----------
        channels: int
            The number of input and output channels

        stride: int or tuple
            The amount of upsampling to do

        groups: int
            Set to 1 for a standard convolution. Set equal to channels to
            make sure there is no cross-talk between channels.
        """
        if isinstance(stride, int):
            stride = (stride, stride)

        assert groups in (1, channels), "Must use no grouping, " + \
                                        "or one group per channel"

        kernel_size = (2 * stride[0] - 1, 2 * stride[1] - 1)
        padding = (stride[0] - 1, stride[1] - 1)
        super().__init__(
            channels, channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=padding,
            groups=groups)

    def reset_parameters(self):
        """Reset the weight and bias."""
        nn.init.constant_(self.bias, 0)
        nn.init.constant_(self.weight, 0)
        bilinear_kernel = self.bilinear_kernel(self.stride)
        for i in range(self.in_channels):
            if self.groups == 1:
                j = i
            else:
                j = 0
            self.weight.data[i, j] = bilinear_kernel

    @staticmethod
    def bilinear_kernel(stride):
        """Generate a bilinear upsampling kernel."""
        num_dims = len(stride)

        shape = (1,) * num_dims
        bilinear_kernel = torch.ones(*shape)

        # The bilinear kernel is separable in its spatial dimensions
        # Build up the kernel channel by channel
        for channel in range(num_dims):
            channel_stride = stride[channel]
            kernel_size = 2 * channel_stride - 1
            # e.g. with stride = 4
            # delta = [-3, -2, -1, 0, 1, 2, 3]
            # channel_filter = [0.25, 0.5, 0.75, 1.0, 0.75, 0.5, 0.25]
            delta = torch.arange(1 - channel_stride, channel_stride)
            channel_filter = (1 - torch.abs(delta / channel_stride))
            # Apply the channel filter to the current channel
            shape = [1] * num_dims
            shape[channel] = kernel_size
            bilinear_kernel = bilinear_kernel * channel_filter.view(shape)
        return bilinear_kernel
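
A quick shape check for the class above (a hedged sketch, not part of the posted file; the 512 input size is an assumption matching the "512" in the checkpoint directory name used later). The half-resolution output is consistent with web.py below, which rescales line coordinates by input_size / 2:

if __name__ == '__main__':
    # Smoke test: verifies output shapes only, no checkpoint required.
    net = MobileV2_MLSD_Tiny(with_deconv=True).eval()
    dummy = torch.randn(1, 3, 512, 512)  # 3-channel RGB input, as trained
    with torch.no_grad():
        out = net(dummy)
    print(out.shape)  # torch.Size([1, 16, 256, 256]): 16 maps at half the input resolution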

@syvince

syvince commented Sep 20, 2024

web.py

import cv2
import os

import numpy as np
import torch
from torch.nn import functional as F
import argparse

from mlsd_tiny import MobileV2_MLSD_Tiny
from albumentations import Normalize
import gradio as gr

from PIL import Image


def detect_weld(img, top_k, min_len, score_thresh):
    img = np.array(img)

    # argparse is used here only as an options container; the Gradio slider
    # values are wired in as the defaults below.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str,
                        default="/home/vsislab/python_work/weld_work/mlsd_pytorch/workdir/models/mobilev2_mlsd_tiny_512_bsize24/best.pth")
    # parser.add_argument()
    parser.add_argument("--input_size", type=int, help="image input size", default=512)
    # parser.add_argument("--sap_thresh", type=float, help="sAP thresh", default=10.0)
    parser.add_argument("--top_k", type=float, help="top k lines", default=top_k)  # 500
    parser.add_argument("--min_len", type=float, help="min len of line", default=min_len)  # 0
    parser.add_argument("--score_thresh", type=float, help="line score thresh", default=score_thresh)  # 0.2
    opt = parser.parse_args()

    # img = cv2.imread()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Build the model with the training-time config (deconv head enabled);
    # a mismatched definition reproduces the load_state_dict error above.
    model = MobileV2_MLSD_Tiny(with_deconv=True).to(device).eval()
    model.load_state_dict(torch.load(opt.model_path, map_location=device), strict=True)

    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    line_info = mlsd_model(opt, model, img)
    lines = line_info["lines"]
    for line in lines:
        x0, y0, x1, y1 = map(int, line)  # convert coordinates to integers
        cv2.line(img, (x0, y0), (x1, y1), (0, 255, 0), 4)  # draw the segment in green, line width 4
    print(line_info)
    # img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return img, line_info


def mlsd_model(opt, model, img):
    # print(args)
    # print(args.model_path)

    h, w, _ = img.shape
    img = cv2.resize(img, (opt.input_size, opt.input_size))
    test_aug = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    img = test_aug(image=img)['image']

    img = img.transpose(2, 0, 1)
    img = torch.from_numpy(img).unsqueeze(0).float().to(next(model.parameters()).device)

    with torch.no_grad():
        batch_outputs = model(img)

    tp_mask = batch_outputs[:, 7:, :, :]
    center_ptss, pred_lines, scores = deccode_lines(tp_mask, opt.score_thresh, opt.min_len, opt.top_k, 3)
    pred_lines = pred_lines.detach().cpu().numpy()
    scores = scores.detach().cpu().numpy()
    pred_lines_list = []
    scores_list = []
    for line, score in zip(pred_lines, scores):
        x0, y0, x1, y1 = line

        x0 = w * x0 / (opt.input_size / 2)
        x1 = w * x1 / (opt.input_size / 2)

        y0 = h * y0 / (opt.input_size / 2)
        y1 = h * y1 / (opt.input_size / 2)

        pred_lines_list.append([x0, y0, x1, y1])
        scores_list.append(score)

    return {
        'width': w,
        'height': h,
        'lines': pred_lines_list,
        'scores': scores_list
    }


def deccode_lines(tpMap, score_thresh, len_thresh, topk_n, ksize=3):
    '''
    Decode line segments from the TP map:
      center:       tpMap[:, 0, :, :]
      displacement: tpMap[:, 1:5, :, :]
    '''
    b, c, h, w = tpMap.shape
    assert b == 1, 'only support bsize==1'
    displacement = tpMap[:, 1:5, :, :]
    center = tpMap[:, 0, :, :]
    heat = torch.sigmoid(center)
    hmax = F.max_pool2d(heat, (ksize, ksize), stride=1, padding=(ksize - 1) // 2)
    keep = (hmax == heat).float()
    heat = heat * keep
    heat = heat.reshape(-1, )

    heat = torch.where(heat < score_thresh, torch.zeros_like(heat), heat)

    scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
    valid_inx = torch.where(scores > score_thresh)
    scores = scores[valid_inx]
    indices = indices[valid_inx]

    yy = torch.floor_divide(indices, w).unsqueeze(-1)
    xx = torch.fmod(indices, w).unsqueeze(-1)
    center_ptss = torch.cat((xx, yy), dim=-1)

    start_point = center_ptss + displacement[0, :2, yy, xx].reshape(2, -1).permute(1, 0)
    end_point = center_ptss + displacement[0, 2:, yy, xx].reshape(2, -1).permute(1, 0)

    lines = torch.cat((start_point, end_point), dim=-1)

    all_lens = (end_point - start_point) ** 2
    all_lens = all_lens.sum(dim=-1)
    all_lens = torch.sqrt(all_lens)
    valid_inx = torch.where(all_lens > len_thresh)

    center_ptss = center_ptss[valid_inx]
    lines = lines[valid_inx]
    scores = scores[valid_inx]

    return center_ptss, lines, scores


def main():
    top_k = gr.Slider(0, 1000, 500, label="top_k")
    min_len = gr.Slider(0, 10, 0, label="min_len")
    score_thresh = gr.Slider(0, 1, 0.2, label="score_thresh")
    line_info = gr.Textbox(label="lines info")
    weld_demo = gr.Interface(detect_weld, [gr.Image(type="pil"), top_k, min_len, score_thresh],
                             ["image", line_info],
                             examples=[
                                 ["example/1.jpg"],
                                 ["example/2.jpg"],
                                 ["example/3.jpg"],
                                 ["example/4.jpg"]
                             ],
                             ).launch(server_name="180.201.6.110", server_port=8860)


if __name__ == '__main__':
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--model_path", type=str,
    #                     default="/home/vsislab/python_work/weld_work/mlsd_pytorch/workdir/models/mobilev2_mlsd_tiny_512_bsize24/best.pth")
    # # parser.add_argument()
    # parser.add_argument("--input_size", type=int, help="image input size", default=512)
    # parser.add_argument("--sap_thresh", type=float, help="sAP thresh", default=10.0)
    # parser.add_argument("--top_k", type=float, help="top k lines", default=1000)
    # parser.add_argument("--min_len", type=float, help="min len of line", default=2.0)
    # parser.add_argument("--score_thresh", type=float, help="line score thresh", default=0.1)
    # opt = parser.parse_args()
    main()
    # main(args)
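
Two usage notes on web.py (observations, not from the thread): launch it with python web.py; server_name and server_port in launch() are the poster's own LAN address, so replace them with your host or drop both to use Gradio's defaults. Also, detect_weld rebuilds the model and reloads the checkpoint on every request; a hedged sketch of loading once at module scope instead, with an illustrative path:

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL = MobileV2_MLSD_Tiny(with_deconv=True).to(DEVICE).eval()
MODEL.load_state_dict(torch.load('workdir/models/best.pth', map_location=DEVICE), strict=True)
# ...then reuse MODEL inside detect_weld instead of rebuilding it per call.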

@syvince

syvince commented Sep 20, 2024

@ganqii @ljl0311
