# Source code for elektronn.training.trainer

# -*- coding: utf-8 -*-
# ELEKTRONN - Neural Network Toolkit
#
# Copyright (c) 2014 - now
# Max-Planck-Institute for Medical Research, Heidelberg, Germany
# Authors: Marius Killinger, Gregor Urban

import sys
import time
from multiprocessing import Process
import numpy as np
from matplotlib import pyplot as plt

from elektronn.utils import pprinttime
from elektronn.net import introspection as intro
from elektronn.net.netcreation import createNet
import CNNData
import traindata
import trainutils
from parallelisation import BackgroundProc


class Trainer(object):
    """
    Object that manages Training of a CNN.

    Parameters
    ----------
    config: trainutils.ConfigObj
        Container for all configurations

    Examples
    --------
    All necessary configuration information is contained in config:

    >>> T = Trainer(config)
    >>> T.loadData()
    >>> T.createNet()
    >>> T.run() # The Training loop

    If the config options ``print_status`` and ``plot_on`` are set the CNN
    progress can be supervised.

    Control during iteration can be exercised by ctrl+c which evokes a
    commandline. There are various shortcuts displayed but in principle all
    attributes of the CNN can be accessed:

    >>> CNN MENU  >>  Debug_Run  <<
    Shortcuts:
    'q' (leave interface), 'abort' (saving params), 'kill'(no saving),
    'save'/'load' (opt:filename), 'sf'/' (show filters)', 'smooth' (smooth filters),
    'sethist <int>', 'setlr <float>', 'setmom <float>' , 'params' print info,
    Change Training mode :('SGD','CG', 'RPROP', 'LBFGS')
    For EVERYTHING else enter your command in the command line
    >>> user@cnn: setlr 0.01        # Change learning rate of SGD
    >>> user@cnn: CG                # Change Training mode CG
                                    # (Optimizer will be compiled on demand)
    Changing Training mode...
    >>> user@cnn: self.config.savename  # Access an attribute of ``trainerInstance``.
                                        # Inputs containing '(' or '=' will result
                                        # in a print of the value
    'Debug_Run'
    >>> user@cnn: print cnn.getDropoutRates()  # To see the return of function add 'print'
    [0.5, 0.5]
    >>> user@cnn: cnn.setOptimizerParams(CG={'max_step': 0.1})  # change CG-'max_step'
    >>> user@cnn: q                 # leave interface
    Continuing Training
    Compiling CG
      Compiling done - in 7.206 s!
    """

    def __init__(self, config=None):
        self.config = config          # trainutils.ConfigObj with all settings
        self.data = None              # data pipeline, created in loadData()
        self.cnn = None               # network, created in createNet()
        self.CG_timeline = []         # CG optimiser loss timeline (mirrored from cnn)
        self.history = []             # [i, t, nll_ema, nll, train_nll, valid_nll, gain, lr] rows
        self.timeline = []            # per-step [i, t, nll_ema, nll, batch_char] rows
        self.errors = []              # [i, t, train_error, valid_error] rows
        self.saved_raw_preview = False  # raw preview images are written only once
[docs] def reset(self): """ Resets all history of NLLs etc and randomizes CNN weights, optimiser hyper-parameters are set to initial values from config """ self.cnn.randomizeWeights() self.cnn.setOptimizerParams( self.config.SGD_params, self.config.CG_params, self.config.RPROP_params, self.config.LBFGS_params, self.config.weight_decay) self.cnn.CG_timeline = [] self.history = [] self.timeline = [] self.errors = [] self.param_vars = []
    def run(self):
        """
        Runs the Training loop until termination (step budget ``n_steps``,
        time budget ``max_runtime`` or manual abort).

        Control during iteration can be exercised by ctrl+c which evokes a
        commandline; see the class docstring for the available shortcuts.
        Parameters are saved periodically and at the end, history/plots are
        written via ``trainutils``.
        """
        save_name = self.config.save_name
        cnn = self.cnn
        data = self.data
        config = self.config
        schedule = self.config.LR_schedule
        # Timing state: t_passed is total training time, t_pt / t_pi are EMAs
        # of pure-train time and whole-iteration time respectively.
        t_passed = 0
        t_per_train = 1
        t_pt = 2
        t_pi = 2
        last_save_t = 0
        save_time = config.param_save_h
        last_save_t2 = 0
        save_time2 = config.initial_prev_h
        nll_ema = 0.65  # initial value for the exponential moving average of the loss
        nll, train_nll, valid_nll, train_error, valid_error = 0, 0, 0, 0, 0
        user_termination = False
        plotting_proc = []  # kept for the (disabled) background-plotting code below
        # The LR schedule is a list of (iteration, new_LR) pairs, consumed in order.
        if (schedule is not None) and (schedule != []):
            next_LR_adjust = schedule.pop(0)
        else:
            next_LR_adjust = (None, None)
        # Pretty-print labels depend on the training target.
        pp_loss = 'MSE' if config.target == 'regression' else 'NLL'
        pp_err = 'std' if config.target == 'regression' else '%'
        # --------------------------------------------------------------------
        if config.background_processes:
            # Fetch batches asynchronously in worker processes.
            n_proc = max(2, int(config.background_processes))
            bg_worker = BackgroundProc(data.getbatch, n_proc=n_proc,
                                       target_kwargs=self.get_batch_kwargs)
        # --------------------------------------------------------------------
        try:
            t0 = time.time()
            for i in xrange(config.n_steps):
                try:
                    if config.background_processes:
                        batch = bg_worker.get()
                    else:
                        batch = data.getbatch(**self.get_batch_kwargs)
                    # Optional extra positional args for the training function.
                    if config.class_weights is not None:
                        batch = batch + (config.class_weights, )
                    if config.label_prop_thresh is not None:
                        batch = batch + (config.label_prop_thresh, )
                    # --------------------------------------------------------
                    nll, nll_instance, t_per_train = cnn.trainingStep(
                        *batch, mode=config.optimizer)  # Update step
                    # --------------------------------------------------------
                    t_per_it = time.time() - t0
                    t0 = time.time()
                    if np.any(np.isnan(nll)) or np.any(np.isinf(nll)):
                        print "The NN diverged to `nan` Loss!!!\n\
You have the chane to inspect the last used examples and the internal state of pipeline in the\
command line.\n\
The last presented training input data is `batch[0]` and the corresponding target `batch[1]`"
                        # Drop into the interactive menu via the interrupt handler.
                        raise KeyboardInterrupt

                    nll_ema = 0.995 * nll_ema + 0.005 * nll  # EMA of the loss
                    t_pt = 0.8 * t_pt + 0.2 * t_per_train    # EMA train time
                    t_pi = 0.8 * t_pi + 0.2 * t_per_it       # EMA iteration time
                    t_passed += t_per_it
                    batch_char = batch[1].mean()  # mean of targets, logged as batch statistic
                    self.timeline.append([i, t_passed, nll_ema, nll, batch_char])

                    # Periodic parameter snapshot (every param_save_h hours).
                    if (t_passed - last_save_t) / 3600 > config.param_save_h:
                        last_save_t = t_passed
                        time_string = '-' + str(save_time) + 'h'
                        cnn.saveParameters(save_name + time_string + '.param', show=False)
                        save_time += config.param_save_h

                    # Periodic preview predictions (first one after initial_prev_h).
                    if self.preview_data is not None:
                        if (t_passed - last_save_t2) / 3600 > config.prev_save_h \
                           or (t_passed / 3600 > config.initial_prev_h and last_save_t2 == 0):  # first time
                            last_save_t2 = t_passed
                            config.preview_kwargs['number'] = save_time2
                            save_time2 += config.prev_save_h
                            try:
                                self.previewSlice(**config.preview_kwargs)
                            except:
                                # Best-effort: preview failures must not kill training.
                                print "Preview Predictions failed. Are the preview raw data in the correct format?"

                    # Scheduled LR change at a fixed iteration.
                    if i == next_LR_adjust[0]:
                        cnn.setSGDLR(np.float32(next_LR_adjust[1]))
                        try:
                            next_LR_adjust = schedule.pop(0)
                        except IndexError:  # list is empty
                            next_LR_adjust = (None, None)

                    if i % 1000 == 0:  # update learning rate (exp. decay)
                        cnn.setSGDLR(np.float32(cnn.SGD_LR.get_value() * config.LR_decay))

                    # Periodic monitoring: train/valid evaluation, history, plots.
                    if (i % config.history_freq[0] == 0) and config.history_freq[0] != 0:
                        lr = cnn.SGD_LR.get_value()
                        self.CG_timeline = cnn.CG_timeline
                        ### Training & Valid Errors ###
                        nll_after = cnn.get_loss(*batch)[0]
                        nll_gain = nll_after - nll  # loss change caused by this update step
                        train_nll, train_error = self.testModel('train')
                        valid_nll, valid_error = self.testModel('valid')
                        if config.target != 'regression':
                            train_error *= 100  # report classification error in percent
                            valid_error *= 100
                        self.errors.append([i, t_passed, train_error, valid_error])
                        self.history.append([i, t_passed, nll_ema, nll, train_nll,
                                             valid_nll, nll_gain, lr])
                        if config.target == 'malis':
                            self.malisPreviewSlice(batch, name=i)

                        ### Monitoring / Output ###
                        cnn.saveParameters(save_name + '-LAST.param', show=False)
                        if config.plot_on and i > 30:
                            ### TODO plotting in a subprocess gives xcb errors on
                            # debian/ubuntu, hence plotInfo is called synchronously
                            # here instead of via multiprocessing.Process.
                            trainutils.plotInfo(self.timeline, self.history,
                                                self.CG_timeline, self.errors, save_name)
                        else:
                            trainutils.saveHist(self.timeline, self.history,
                                                self.CG_timeline, self.errors, save_name)

                        if config.print_status:
                            out = '%05i %sm=%.3f, train=%05.2f%s, valid=%05.3f%s, prev=%04.1f, NLLdiff=%+.1e, LR=%.5f, %.1f it/s, ' \
                                  % (i, pp_loss, nll_ema, train_error, pp_err,
                                     valid_error, pp_err, batch_char * 100, nll_gain,
                                     cnn.SGD_LR.get_value(), 1.0 / t_pi)
                            t = pprinttime(t_passed)
                            print out + t

                # User Interface ###################################################
                except KeyboardInterrupt:
                    out = '%05i %s=%.5f, NLL=%.4f, train=%.5f, valid=%.5f, train=%.3f%s, valid=%.3f%s,\n\
LR=%.6f, MOM=%.6f, %.1f GPU-it/s, %.1f CPU-it/s, '\
                          % (i, pp_loss, nll_ema, nll, train_nll, valid_nll,
                             train_error, pp_err, valid_error, pp_err,
                             cnn.SGD_LR.get_value(), cnn.SGD_momentum.get_value(),
                             1.0 / t_pt, 1.0 / t_pi)
                    t = pprinttime(t_passed)
                    print out + t
                    # Like a command line, it must be here to access workspace variables
                    trainutils.pprintmenu(save_name)
                    while True:
                        try:
                            ret = trainutils.userInput(cnn, config.history_freq)
                            plt.pause(0.001)
                            if ret is None or ret == "":
                                continue
                            if ret == "abort":
                                user_termination = True
                                break
                            elif ret == 'kill':
                                return  # leave run() without saving end-parameters
                            elif ret in ['SGD', 'RPROP', 'CG', 'LBFGS', 'Adam']:
                                config.optimizer = ret
                            elif ret == 'q':
                                print "Continuing Training"
                                break
                            elif ret == 'sf':
                                intro.plotFilters(cnn)
                            else:
                                if '(' in ret or '=' in ret:
                                    # execute statements and assignments
                                    exec(ret)
                                else:
                                    # print value of identifiers
                                    exec('print ' + ret)
                        except:
                            sys.excepthook(*sys.exc_info())  # show info on error

                    if self.config.background_processes:
                        bg_worker.reset()
                    # reset time after user interaction, otherwise the pause
                    # would appear as training time in the plot
                    t0 = time.time()
                # End UI ###########################################################

                if (t_passed > config.max_runtime) or user_termination:
                    # This is in the epoch/UI loop
                    print 'Timeout or manual Termination'
                    break

            # This is OUTSIDE the training loop i.e. the last block of ``run``
            self.cnn.saveParameters(save_name + "_end.param")
            trainutils.plotInfo(self.timeline, self.history, self.CG_timeline,
                                self.errors, save_name)
            print 'End of Training'
            print '#' * 60 + '\n' + '#' * 60 + '\n'
        # ---------------------end of run()-----------------------------------
        except:
            sys.excepthook(*sys.exc_info())  # show info on error
        finally:
            if config.background_processes:
                bg_worker.shutdown()
    def loadData(self):
        """
        Creates the data pipeline (``self.data``) and the batch-kwarg dicts
        used by ``run``/``testModel`` according to ``self.config``.

        Image training uses ``CNNData.CNNData``; non-image training
        instantiates the configured data class from ``traindata`` (or from a
        file, if ``data_class_name`` is a tuple). Also loads the optional
        preview data into ``self.preview_data``.
        """
        config = self.config
        if self.config.mode != 'vect-scalar' and self.config.data_class_name is None:
            # image training
            # NOTE(review): relies on numpy's `~` acting as logical-not on the
            # np.bool_ returned by np.any (would be wrong for a plain Python
            # bool) -- confirm MFP is always array-like here.
            strided = ~np.any(config.MFP) and config.mode == 'img-img'
            self.get_batch_kwargs = dict(
                batch_size=config.batch_size,
                strided=strided,
                flip=config.flip_data,
                grey_augment_channels=config.grey_augment_channels,
                ret_info=config.lazy_labels,
                ret_example_weights=config.use_example_weights,
                warp_on=config.warp_on,
                ignore_thresh=config.example_ignore_threshold)
            # the source is replaced in self.testModel to be valid
            self.get_batch_kwargs_test = dict(
                batch_size=config.monitor_batch_size,
                strided=strided,
                flip=config.flip_data,
                grey_augment_channels=config.grey_augment_channels,
                ret_info=config.lazy_labels,
                ret_example_weights=config.use_example_weights,
                warp_on=False,
                ignore_thresh=config.example_ignore_threshold)  # no warp
            self.data = CNNData.CNNData(
                config.patch_size,
                config.dimensions.pred_stride,
                config.dimensions.offset,
                config.n_dim,
                config.n_lab,
                config.anisotropic_data,
                config.mode,
                config.zchxy_order,
                config.border_mode,
                config.pre_process,
                config.upright_x,
                True if config.target == 'regression' else False,
                config.target if config.target in ['malis', 'affinity'] else False)
                # last arg: return affinity graph instead of boundaries
            self.data.addDataFromFile(config.data_path, config.label_path,
                                      config.d_files, config.l_files,
                                      config.cube_prios, config.valid_cubes,
                                      config.downsample_xy)

            if self.config.preview_data_path is not None:
                data = trainutils.h5Load(self.config.preview_data_path)
                # Normalise to a list of arrays; uint8 raw data is scaled to [0,1].
                if not (isinstance(data, list) or isinstance(data, (tuple, list))):
                    data = [data, ]
                data = [d.astype(np.float32) / 255 for d in data]
                self.preview_data = data
            else:
                self.preview_data = None
        else:
            # non-image training
            self.get_batch_kwargs = dict(batch_size=config.batch_size)
            self.get_batch_kwargs.update(self.config.data_batch_kwargs)
            # the source is replaced in self.testModel to be valid
            self.get_batch_kwargs_test = dict(batch_size=config.monitor_batch_size)
            if isinstance(self.config.data_class_name, tuple):
                # (file_path, class_name): load the data class from a file
                Data = trainutils.import_variable_from_file(*self.config.data_class_name)
            else:
                Data = getattr(traindata, self.config.data_class_name)
            self.data = Data(**self.config.data_load_kwargs)
            self.preview_data = None
[docs] def createNet(self): """ Creates CNN according to config """ n_lab = self.data.n_lab if self.config.class_weights is not None: if self.config.target == 'nll': assert len(self.config.class_weights) == n_lab,\ "The number of class weights must equal the number of classes" if self.config.mode != 'vect-scalar': # image training n_ch = self.data.n_ch self.cnn = createNet(self.config, self.config.patch_size, n_ch, n_lab, self.config.dimensions) else: # non-image training n_ch = None if self.config.rnn_layer_kwargs is not None: n_ch = self.data.n_taps # must be None if the data should be repeated self.cnn = createNet(self.config, self.data.example_shape, n_ch, n_lab, None)
[docs] def debugGetCNNBatch(self): """ Executes ``getbatch`` but with un-strided labels and always returning info. The first batch example is plotted and the whole batch is returned for inspection. """ if self.config.mode == 'img-img': batch = self.data.getbatch( self.config.monitor_batch_size, source='train', strided=False, flip=self.config.flip_data, grey_augment_channels=self.config.grey_augment_channels, ret_info=True) try: data, label, info1, info2 = batch except: data, label, seg, info1, info2 = batch if len(label.shape) == 5: # affinities (bs, 3, z, x, y) print "Plot Batch: Showing min affinity only." label = np.min(label, axis=1) if self.config.n_dim == 2: CNNData.plotTrainingTarget(data[0, 0], label[0], 1) else: i = int(self.config.dimensions.offset[0]) CNNData.plotTrainingTarget(data[0, i, 0], label[0, 0], 1) # print "Info1=",info1 # print "Info2=",info2 plt.show() plt.savefig('debugGetCNNBatch.png', bbox_inches='tight') plt.pause(0.01) plt.pause(2.0) plt.close('all') plt.pause(0.01) return data, label, info1, info2 else: print "debugGetCNNBatch(): This function is only available for 'img-img' training mode"
    def testModel(self, data_source):
        """
        Computes NLL and error/accuracy on batch with ``monitor_batch_size``

        Dropout is temporarily disabled for the evaluation and restored
        afterwards. The monitoring batch is evaluated in chunks of
        ``config.batch_size`` and the results are averaged, weighted by
        chunk size.

        Parameters
        ----------
        data_source: string
            'train' or 'valid'

        Returns
        -------
        NLL, error:
            (NaN, NaN) if 'valid' is requested but no validation data exists.
        """
        if data_source == 'valid':
            # No validation cubes configured -> nothing to evaluate.
            if not hasattr(self.data, 'valid_d') \
               or not hasattr(self.data.valid_d, '__len__') \
               or len(self.data.valid_d) == 0:
                return np.nan, np.nan  # 0, 0
        kwargs = dict(self.get_batch_kwargs_test)  # copy because it is modified in next line!
        kwargs['source'] = data_source
        batch = self.data.getbatch(**kwargs)
        # Extra positional args appended to every get_error call.
        y_aux = []
        if self.config.class_weights is not None:
            y_aux.append(self.config.class_weights)
        if self.config.label_prop_thresh is not None:
            y_aux.append(self.config.label_prop_thresh)
        # Evaluate deterministically: switch dropout off, restore at the end.
        rates = self.cnn.getDropoutRates()
        self.cnn.setDropoutRates([0.0, ] * len(rates))
        n = len(batch[0])
        nll = 0
        error = 0
        # Process the monitoring batch in chunks of config.batch_size.
        for j in xrange(int(np.ceil(np.float(n) / self.config.batch_size))):
            d = batch[0][j * self.config.batch_size:(j + 1) * self.config.batch_size]  # data
            l = batch[1][j * self.config.batch_size:(j + 1) * self.config.batch_size]  # label
            if len(batch) > 2:
                # Slice any additional per-example arrays (e.g. weights) alike.
                aux = []
                for b in batch[2:]:
                    aux.append(b[j * self.config.batch_size:(j + 1) * self.config.batch_size])
                nl, er, pred = self.cnn.get_error(d, l, *(aux + y_aux))
            else:
                nl, er, pred = self.cnn.get_error(d, l, *y_aux)
            # Weight by chunk length (last chunk may be smaller).
            nll += nl * len(d)
            error += er * len(d)
        nll /= n
        error /= n
        self.cnn.setDropoutRates(rates)  # restore old rates
        return nll, error
    def predictAndWrite(self, raw_img, number=0, export_class='all',
                        block_name='', z_thick=5):
        """
        Predict and and save a slice as preview image

        Parameters
        ----------
        raw_img : np.ndarray
            raw data in the format (ch, x, y, z)
        number: int/float
            consecutive number for the save name (i.e. hours, iterations etc.)
        export_class: str or int
            'all' writes images of all classes, otherwise only the class
            with index ``export_class`` (int) is saved. A list/tuple of ints
            exports those classes; 'malis'/'affinity' exports an RGB affinity
            image.
        block_name: str
            Name/number to distinguish different raw_imges
        z_thick: int
            number of z-slices (centred in the prediction) to export
        """
        block_name = str(block_name)
        pred = self.cnn.predictDense(raw_img)  # returns (k, x, y(, z))
        z_sh = pred.shape[-1]
        if pred.shape[0] == 3:
            # NOTE(review): special-case for 3-channel affinity output -- the
            # first channel is overwritten with the per-voxel min affinity.
            print "WARNING: hack active for affinity previews"
            pred[0] = pred.min(axis=0)
        # Keep only z_thick central z-slices of the prediction.
        pred = pred[:, :, :, (z_sh - z_thick) // 2:(z_sh - z_thick) // 2 + z_thick]
        save_name = self.config.save_name
        for z in xrange(pred.shape[3]):
            if export_class == 'all':
                # One greyscale image per class and z-slice.
                for c in xrange(pred.shape[0]):
                    plt.imsave('%s-pred-%s-c%i-z%i-%shrs.png'
                               % (save_name, block_name, c, z, number),
                               pred[c, :, :, z], cmap='gray')
            elif export_class in ['malis', 'affinity']:
                # Channels 0, 2, 4 stacked as an RGB affinity image.
                plt.imsave('%s-pred-%s-aff-z%i-%shrs.png'
                           % (save_name, block_name, z, number),
                           np.transpose(pred[0:6:2, :, :, z], (1, 2, 0)), cmap='gray')
            else:
                if isinstance(export_class, (list, tuple)):
                    for c in export_class:
                        plt.imsave('%s-pred-%s-c%i-z%i-%shrs.png'
                                   % (save_name, block_name, c, z, number),
                                   pred[c, :, :, z], cmap='gray')
                else:
                    c = int(export_class)
                    plt.imsave('%s-pred-%s-c%i-z%i-%shrs.png'
                               % (save_name, block_name, c, z, number),
                               pred[c, :, :, z], cmap='gray')
        if not self.saved_raw_preview:  # only do once
            # NOTE(review): the flag itself is set to True by the callers
            # (previewSlice*), not here.
            z_off = 0 if len(self.config.dimensions.offset) == 2 \
                else int(self.config.dimensions.offset[0])
            for z in xrange(pred.shape[3]):
                plt.imsave('%s-raw-%s-z%i.png' % (save_name, block_name, z),
                           raw_img[0, :, :, z + z_off], cmap='gray')
[docs] def previewSliceFromTrainData(self, cube_i=0, off=(0, 0, 0), sh=(10, 400, 400), number=0, export_class='all'): """ Predict and and save a selected slice from the training data as preview Parameters ---------- cube_i: int index of source cube in CNNData off: 3-tuple of int start index of slice to cut from cube (z,x,y) sh: 3-tuple of int shape of cube to cut (z,x,y) number: int consecutive number for the save name (i.e. hours, iterations etc.) export_class: str or int 'all' writes images of all classes, otherwise only the class with index ``export_class`` (int) is saved. """ if not self.config.mode == 'img-img': print "previewSliceFromTrainData(): This function is only available for 'img-img' training mode" return if self.cnn.n_dim == 3: min_z = self.cnn.input_shape[1] if min_z > sh[0]: sh = list(sh) sh[0] = min_z raw_img = self.data.train_d[cube_i] raw_img = raw_img[off[0]:off[0] + sh[0], :, off[1]:off[1] + sh[1], off[1]:off[1] + sh[1]] raw_img = np.transpose(raw_img, (1, 2, 3, 0)) # (z,ch,x,y) --> (ch,x,y,z) self.predictAndWrite(raw_img, number, export_class) self.saved_raw_preview = True
    def previewSlice(self, number=0, export_class='all', max_z_pred=5):
        """
        Predict and and save a data from a separately loaded file as preview

        Parameters
        ----------
        number: int/float
            consecutive number for the save name (i.e. hours, iterations etc.)
        export_class: str or int
            'all' writes images of all classes, otherwise only the class
            with index ``export_class`` (int) is saved.
        max_z_pred: int
            approximate maximal number of z-slices to produce
            (depends on CNN architecture)
        """
        if not self.config.mode == 'img-img':
            print "previewSlice(): This function is only available for 'img-img' training mode"
            return
        assert self.preview_data is not None, \
            "You must provide preview data in order to call this function"
        for example_no, raw_img in enumerate(self.preview_data):
            z_sh = raw_img.shape[-1]
            if self.cnn.n_dim == 3:
                # Determine the input z-thickness needed to produce roughly
                # max_z_pred output slices, given the CNN's output stride.
                strd_z = self.cnn.output_strides[0]
                out_z = self.cnn.output_shape[2] * strd_z
                min_z = self.cnn.input_shape[1] + strd_z - 1
                z_thick = min_z if out_z > max_z_pred \
                    else min_z + strd_z * int(np.ceil(float(max_z_pred - out_z) / strd_z))
            else:
                z_thick = max_z_pred
            assert z_thick <= z_sh, \
                "The preview slices are too small in z-direction for this CNN"
            # Crop z_thick central z-slices; add a channel axis for 3d input.
            if raw_img.ndim == 3:
                raw_img = raw_img[None, :, :,
                                  (z_sh - z_thick) // 2:(z_sh - z_thick) // 2 + z_thick]
            elif raw_img.ndim == 4:
                raw_img = raw_img[:, :, :,
                                  (z_sh - z_thick) // 2:(z_sh - z_thick) // 2 + z_thick]
            self.predictAndWrite(raw_img, number, export_class, example_no, max_z_pred)
        self.saved_raw_preview = True
    def malisPreviewSlice(self, batch, name='A'):
        """
        Computes MALIS statistics for ``batch``, prints a summary and pickles
        predictions, ground truth and MALIS counts to ``MALIS-<name>.pkl``.

        Parameters
        ----------
        batch: tuple
            (data, aff_gt, seg_gt, ...) as produced by the training pipeline.
        name: str/int
            suffix for the pickle file name (e.g. the iteration number).
        """
        pred = self.cnn.class_probabilities(batch[0])[0]  # (6, z, x, y)
        malis = self.cnn.malis_stats(*batch[:3])
        # malis_stats returns: nll, n_pos, n_neg, n_tot, false_splits,
        # false_merges, rand_index, pos_count, neg_count, labels
        nll, n_pos, n_neg, n_tot, false_splits, false_merges, rand_index, \
            pos_count, neg_count = malis
        data, aff_gt, seg_gt = batch[:3]
        print "NLL : ", nll
        print "N total: ", n_tot
        print "N pos  : ", n_pos
        print "N neg  : ", n_neg
        print "Splits : ", false_splits
        print "Mergers: ", false_merges
        print "Rand-Index: ", rand_index
        # Move the channel axis last for saving: (ch, z, x, y) -> (z, x, y, ch).
        pred_slices = np.transpose(pred[1::2], (1, 2, 3, 0))
        pos_slices = np.transpose(pos_count, (1, 2, 3, 0))
        neg_slices = np.transpose(neg_count, (1, 2, 3, 0))
        neg_slices = np.log(neg_slices + 1)  # compress the dynamic range
        data = data[0, :, 0]
        aff_gt = np.transpose(aff_gt[0], (1, 2, 3, 0))
        seg_gt = seg_gt[0]
        trainutils.pickleSave([pred_slices, aff_gt, pos_slices, neg_slices,
                               seg_gt, data], 'MALIS-%s.pkl' % (name, ))