Posted by arcosf via MXNet Forum on 2020/08/04 17:21:11 UTC

[MXNet Forum] [Gluon] Object detection Transfer Learning


I solved this issue by using the `yolo3_darknet53_custom` model from the GluonCV model zoo. Unlike the VOC and COCO variants, this model takes a `classes` argument: the list of new class names your model will be trained on. Here are the changes I made to the training script (I am currently training locally and plan to move this to BYOM script mode):

    # classes: names of the new classes the model will be trained on
    classes = ['class_a', 'class_b']  # hypothetical placeholders, replace with your own class names
    # network
    if args.dataset in ['voc', 'coco']:
        net_name = '_'.join(('yolo3', args.network, args.dataset))
        model_kwargs = {}  # built-in VOC/COCO models define their own classes
    else:
        net_name = '_'.join(('yolo3', args.network, 'custom'))  # e.g. yolo3_darknet53_custom
        model_kwargs = {'classes': classes}  # *_custom models require the class list
    args.save_prefix += net_name
    # use sync bn if specified
    if args.syncbn and len(ctx) > 1:
        print('Requested sync batch normalization')
        net = get_model(net_name, pretrained_base=True, norm_layer=gluon.contrib.nn.SyncBatchNorm,
                        norm_kwargs={'num_devices': len(ctx)}, **model_kwargs)
        # async_net is used by the cpu worker
        if args.pretrained:
            print("Will load pretrained weights of COCO")
            async_net = get_model(net_name, pretrained_base=False, pretrained=True, ctx=ctx, **model_kwargs)
        else:
            print("Will load non pretrained weights")
            async_net = get_model(net_name, pretrained_base=False, pretrained=False, ctx=ctx, **model_kwargs)
    else:
        print('No sync batch normalization will be performed')
        if args.pretrained:
            print("Will load pretrained weights of COCO")
            net = get_model(net_name, pretrained_base=True, pretrained=True, ctx=ctx, **model_kwargs)
        else:
            print("Will load non pretrained weights")
            net = get_model(net_name, pretrained_base=True, pretrained=False, ctx=ctx, **model_kwargs)
        async_net = net
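
One thing worth double-checking: in GluonCV's model zoo, the `*_custom` constructors appear to ignore `pretrained=True` (they emit a warning) and instead reuse pretrained detector weights through a `transfer` argument, e.g. `transfer='coco'`. The pure-Python sketch below shows how the model name and the `get_model` keyword arguments could be chosen under that assumption, without needing MXNet installed. The helper `yolo3_model_spec` and the class names `'cat'` and `'dog'` are hypothetical, and the `transfer` behavior should be verified against the GluonCV version you have installed.

```python
def yolo3_model_spec(network, dataset, classes, pretrained):
    """Return (model_name, kwargs) for a GluonCV get_model(...) call.

    Hypothetical helper. Assumes *_custom models take a `classes` list and
    reuse COCO detector weights via transfer='coco' instead of pretrained=True.
    """
    if dataset in ('voc', 'coco'):
        # Built-in models define their own classes, so `classes` is not passed.
        return f'yolo3_{network}_{dataset}', {'pretrained': pretrained}
    # Custom dataset: start from an ImageNet-pretrained backbone.
    kwargs = {'classes': list(classes), 'pretrained_base': True}
    if pretrained:
        # Assumed GluonCV option: load yolo3_<network>_coco, then reset the class outputs.
        kwargs['transfer'] = 'coco'
    return f'yolo3_{network}_custom', kwargs


if __name__ == '__main__':
    name, kwargs = yolo3_model_spec('darknet53', 'custom', ['cat', 'dog'], pretrained=True)
    print(name)    # yolo3_darknet53_custom
    print(kwargs)  # {'classes': ['cat', 'dog'], 'pretrained_base': True, 'transfer': 'coco'}
    # With GluonCV installed, this would become:
    # net = gluoncv.model_zoo.get_model(name, **kwargs)
```

Keeping the VOC/COCO case separate matters because those models carry a fixed class list. Only the custom variant should receive `classes`.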

Hope this helps! Happy training!