Posted to discuss-archive@mxnet.apache.org by arcosf via MXNet Forum <mx...@discoursemail.com.INVALID> on 2020/08/04 17:21:11 UTC

[MXNet Forum] [Gluon] Object detection Transfer Learning


Hi!

I solved this issue by using the `yolo3_darknet53_custom` model from the model zoo. It takes the list of new classes that your model will be trained on. Here are the changes I made to the training script (I am currently training locally and plan to move this to BYOM scripting mode):


    from mxnet import gluon
    from gluoncv.model_zoo import get_model

    # classes of the new (custom) training set
    classes = train_dataset.classes
    # network
    if args.dataset in ['voc', 'coco']:
        net_name = '_'.join(('yolo3', args.network, args.dataset))
        # built-in voc/coco models fix their own class list, so no classes= here
        model_kwargs = {}
    else:
        net_name = '_'.join(('yolo3', args.network, 'custom'))
        # *_custom models require the list of new classes
        model_kwargs = {'classes': classes}
    args.save_prefix += net_name
    if args.pretrained:
        print("Will load pretrained weights")
        if args.dataset in ['voc', 'coco']:
            model_kwargs['pretrained'] = True
        else:
            # *_custom models ignore pretrained=True; transfer='coco' loads the
            # COCO detector and resets its output layer to the new classes
            model_kwargs['transfer'] = 'coco'
    else:
        print("Will load non pretrained weights")
    # use sync bn if specified
    if args.syncbn and len(ctx) > 1:
        print('Requested sync batch normalization')
        net = get_model(net_name, pretrained_base=True, norm_layer=gluon.contrib.nn.SyncBatchNorm,
                        norm_kwargs={'num_devices': len(ctx)}, **model_kwargs)
        async_net = get_model(net_name, pretrained_base=False, **model_kwargs)  # used by cpu worker
    else:
        print('No sync batch normalization will be performed')
        net = get_model(net_name, pretrained_base=True, ctx=ctx, **model_kwargs)
        async_net = net  # keep async_net defined for the rest of the script
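If you want to sanity-check the model-name and argument selection without MXNet installed, here is a minimal sketch of that logic as a pure function. `yolo3_model_spec` is a hypothetical helper, not part of GluonCV. It assumes that `*_custom` models take `classes=` and start from COCO weights via `transfer='coco'`, while the built-in `voc`/`coco` models take `pretrained=` and set their own classes.

    from typing import Any, Dict, List, Tuple

    # Hypothetical helper (not a GluonCV API): picks the model-zoo name and the
    # keyword arguments you would pass to get_model().
    BUILTIN_DATASETS = ('voc', 'coco')


    def yolo3_model_spec(network: str, dataset: str, classes: List[str],
                         pretrained: bool) -> Tuple[str, Dict[str, Any]]:
        """Return (model name, get_model kwargs) for a YOLOv3 variant."""
        if dataset in BUILTIN_DATASETS:
            # e.g. yolo3_darknet53_voc: the class list comes with the dataset
            name = '_'.join(('yolo3', network, dataset))
            kwargs: Dict[str, Any] = {'pretrained': pretrained}
        else:
            # e.g. yolo3_darknet53_custom: the caller supplies the new classes
            name = '_'.join(('yolo3', network, 'custom'))
            kwargs = {'classes': list(classes)}
            if pretrained:
                # assumption: start from the COCO detector and re-target its head
                kwargs['transfer'] = 'coco'
        return name, kwargs


    if __name__ == '__main__':
        name, kwargs = yolo3_model_spec('darknet53', 'mydata', ['cat', 'widget'], True)
        print(name, kwargs)

The returned pair would then be passed on as `get_model(name, pretrained_base=True, **kwargs)`. Keeping the selection in one place means the sync-BN and non-sync branches cannot drift apart.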


Hope this helps! Happy training!





---
[Visit Topic](https://discuss.mxnet.io/t/object-detection-transfer-learning/2477/6) or reply to this email to respond.