# File size: 5,987 Bytes
# 197d4ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import numpy as np
import torch
import torch.optim
import os
import random 

from methods import backbone
from methods.backbone_multiblock import model_dict
from data.datamgr import SimpleDataManager, SetDataManager
from methods.StyleAdv_RN_GNN import StyleAdvGNN

from options import parse_args, get_resume_file, load_warmup_state
from test_function_fwt_benchmark import test_bestmodel
from test_function_bscdfsl_benchmark import test_bestmodel_bscdfsl


def train(base_loader, val_loader,  model, start_epoch, stop_epoch, params):

  # get optimizer and checkpoint path
  optimizer = torch.optim.Adam(model.parameters())
  if not os.path.isdir(params.checkpoint_dir):
    os.makedirs(params.checkpoint_dir)

  # for validation
  max_acc = 0
  total_it = 0

  # start
  for epoch in range(start_epoch, stop_epoch):
    model.train()
    total_it = model.train_loop(epoch, base_loader, optimizer, total_it) #model are called by reference, no need to return
    model.eval()

    acc = model.test_loop( val_loader)
    if acc > max_acc :
      print("best model! save...")
      max_acc = acc
      outfile = os.path.join(params.checkpoint_dir, 'best_model.tar')
      torch.save({'epoch':epoch, 'state':model.state_dict()}, outfile)
    else:
      print("GG! best accuracy {:f}".format(max_acc))

    #if ((epoch + 1) % params.save_freq==0) or (epoch==stop_epoch-1):
    if(epoch == stop_epoch-1):
      outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
      torch.save({'epoch':epoch, 'state':model.state_dict()}, outfile)

  return model


def record_test_result(params):
  """Evaluate on the five benchmark datasets and log the results to ``acc.txt``.

  Opens ``<params.checkpoint_dir>/acc.txt`` for writing (truncating any
  previous content), writes a header line, then calls ``test_bestmodel``
  once each for miniImagenet, cub, cars, places and plantae, passing the
  open file so the callee can append its result.

  Args:
    params: options namespace; reads ``checkpoint_dir``, ``name``,
      ``n_shot`` and ``method``.

  Returns:
    None.
  """
  acc_file_path = os.path.join(params.checkpoint_dir, 'acc.txt')
  # -1 is forwarded unchanged to test_bestmodel; presumably it selects the
  # best checkpoint rather than a numbered epoch -- confirm in test_bestmodel.
  epoch_id = -1
  name = params.name
  n_shot = params.n_shot
  method = params.method
  datasets = ('miniImagenet', 'cub', 'cars', 'places', 'plantae')

  # the context manager closes (and flushes) the log even if an evaluation raises
  with open(acc_file_path, 'w') as acc_file:
    print('epoch', epoch_id, 'miniImagenet:', 'cub:', 'cars:', 'places:', 'plantae:', file=acc_file)
    for dataset in datasets:
      test_bestmodel(acc_file, name, method, dataset, n_shot, epoch_id)


def record_test_result_bscdfsl(params):
  """Evaluate on the four BSCD-FSL datasets and log to ``acc_bscdfsl.txt``.

  Opens ``<params.checkpoint_dir>/acc_bscdfsl.txt`` for writing (truncating
  any previous content), writes a header line, then calls
  ``test_bestmodel_bscdfsl`` once each for ChestX, ISIC, EuroSAT and
  CropDisease, passing the open file so the callee can append its result.

  Args:
    params: options namespace; reads ``checkpoint_dir``, ``name``,
      ``n_shot`` and ``method``.

  Returns:
    None.
  """
  print('hhhhhhh testing for bscdfsl')
  acc_file_path = os.path.join(params.checkpoint_dir, 'acc_bscdfsl.txt')
  # -1 is forwarded unchanged to test_bestmodel_bscdfsl; presumably it selects
  # the best checkpoint rather than a numbered epoch -- confirm in the callee.
  epoch_id = -1
  name = params.name
  n_shot = params.n_shot
  method = params.method
  datasets = ('ChestX', 'ISIC', 'EuroSAT', 'CropDisease')

  # the context manager closes (and flushes) the log even if an evaluation raises
  with open(acc_file_path, 'w') as acc_file:
    print('epoch', epoch_id, 'ChestX:', 'ISIC:', 'EuroSAT:', 'CropDisease', file=acc_file)
    for dataset in datasets:
      test_bestmodel_bscdfsl(acc_file, name, method, dataset, n_shot, epoch_id)


# --- main function ---
if __name__=='__main__':
  # Script entry point: seed RNGs, build loaders and model, optionally restore
  # weights (resume checkpoint or warm-up encoder), train, and report timing.
  import time

  # fix seed for python, numpy and torch (CPU + all GPUs); cudnn is forced
  # into deterministic mode with autotuning disabled
  seed = 0
  print("set seed = %d" % seed)
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False

  # parser argument
  params = parse_args('train')

  # output and tensorboard dir
  params.tf_dir = '%s/log/%s'%(params.save_dir, params.name)
  params.checkpoint_dir = '%s/checkpoints/%s'%(params.save_dir, params.name)
  # exist_ok avoids the check-then-create race of isdir() + makedirs()
  os.makedirs(params.checkpoint_dir, exist_ok=True)

  # dataloader
  print('\n--- prepare dataloader ---')
  print('  train with single seen domain {}'.format(params.dataset))
  base_file  = os.path.join(params.data_dir, params.dataset, 'base.json')
  val_file   = os.path.join(params.data_dir, params.dataset, 'val.json')

  # model
  print('\n--- build model ---')
  image_size = 224

  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small
  n_query = max(1, int(16* params.test_n_way/params.train_n_way))

  train_few_shot_params    = dict(n_way = params.train_n_way, n_support = params.n_shot)
  base_datamgr            = SetDataManager(image_size, n_query = n_query,  **train_few_shot_params)
  base_loader             = base_datamgr.get_data_loader( base_file , aug = params.train_aug )

  test_few_shot_params     = dict(n_way = params.test_n_way, n_support = params.n_shot)
  val_datamgr             = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
  val_loader              = val_datamgr.get_data_loader( val_file, aug = False)

  model           = StyleAdvGNN( model_dict[params.model], tf_path=params.tf_dir, **train_few_shot_params)
  model = model.cuda()

  # load model: either resume a full checkpoint, or start from a warm-up encoder
  start_epoch = params.start_epoch
  stop_epoch = params.stop_epoch
  if params.resume != '':
    resume_file = get_resume_file('%s/checkpoints/%s'%(params.save_dir, params.resume), params.resume_epoch)
    # NOTE: when no resume file is found, training silently starts from the
    # freshly initialised model at params.start_epoch.
    if resume_file is not None:
      tmp = torch.load(resume_file)
      start_epoch = tmp['epoch']+1
      model.load_state_dict(tmp['state'])
      print('  resume the training with at {} epoch (model file {})'.format(start_epoch, params.resume))
  else:
    # 'gg3b0' is treated as "no warm-up given" -- presumably the option's
    # default value; confirm in options.parse_args
    if params.warmup == 'gg3b0':
      raise Exception('Must provide the pre-trained feature encoder file using --warmup option!')
    state = load_warmup_state('%s/checkpoints/%s'%(params.save_dir, params.warmup), params.method)
    # strict=False: keys missing from / extra to the warm-up state are tolerated
    model.feature.load_state_dict(state, strict=False)

  # training
  start = time.perf_counter()
  print('\n--- start the training ---')
  model = train(base_loader, val_loader, model, start_epoch, stop_epoch, params)
  end = time.perf_counter()

  # Average over the epochs actually run in this process: a resumed run only
  # executes stop_epoch - start_epoch epochs. The max() guards against a zero
  # or negative count (nothing trained), which would otherwise divide by zero.
  elapsed = end - start
  epochs_run = max(stop_epoch - start_epoch, 1)
  print('Running time: %s Seconds: %s Min: %s Min per epoch'%(elapsed, elapsed/60, elapsed/60/epochs_run))

  # Benchmark evaluation is disabled by default; uncomment to run it after training.
  # testing
  #record_test_result(params)
  # testing bscdfsl
  #record_test_result_bscdfsl(params)