Source code for nncore.engine.engine

# Copyright (c) Ye Liu. Licensed under the MIT License.

from collections import OrderedDict

import torch

import nncore
from nncore.nn import build_model
from nncore.optim import build_optimizer
from nncore.utils import CfgNode
from .buffer import Buffer
from .builder import build_dataloader, build_hook
from .comm import gather, is_distributed, is_main_process, sync
from .hooks import Hook
from .utils import get_checkpoint, load_checkpoint

_DEFAULT_STAGES = [
    dict(
        epochs=5,
        optimizer=dict(type='SGD', lr=1e-2, momentum=0.9, weight_decay=1e-4),
        lr_schedule=dict(type='iter', policy='cosine'),
        warmup=dict(type='iter', policy='linear', steps=500, ratio=0.001),
        validation=dict(interval=1))
]
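# The default schedule above trains for 5 epochs with SGD (momentum 0.9,
# weight decay 1e-4), an iter-based cosine learning rate schedule, a linear
# warm-up over the first 500 iterations starting at 0.1% of the base lr, and
# validation after every epoch.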

_DEFAULT_HOOKS = [
    'TimerHook', 'LrUpdaterHook', 'OptimizerHook', 'CheckpointHook',
    'EvalHook', 'EventWriterHook'
]


@nncore.bind_getter('mode', 'max_stages', 'max_epochs', 'max_iters',
                    'start_iter', 'stage', 'epoch', 'iter', 'kwargs')
class Engine(object):
    """
    An engine that can take over the whole training, validation, and testing
    process, with all the babysitting work (stage control, optimizer
    configuration, lr scheduling, checkpoint management, metrics &
    tensorboard writing, etc.) done automatically.

    Args:
        model (:obj:`nn.Module` | cfg | str): The model or config of the
            model. The :obj:`forward` method of the model should return a
            dict containing a ``_avg_factor`` field indicating the number of
            samples in the current batch, and optionally a ``_out`` field
            denoting the model outputs to be collected and evaluated.
        data_loaders (dict | str): The configs of data loaders for training,
            validation, and testing. The dict should be in the format of
            ``dict(train=train_loader, val=val_loader, test=test_loader)``.
        stages (list[dict] | dict | None, optional): The stage config or list
            of stage configs to be scheduled. Each stage config should be a
            dict containing the following fields:

            - `epochs` (int): Number of epochs in the stage.
            - `optimizer` (:obj:`optim.Optimizer` | dict): The optimizer or
              an optimizer config containing the following fields:

                - `type` (str): Type of the optimizer, which can be accessed
                  via :obj:`torch.optim` attributes, e.g. ``'SGD'``.
                - Other configs for the optimizer, e.g. ``lr=0.01,
                  momentum=0.9``.

            - `lr_schedule` (dict, optional): The learning rate schedule
              config containing the following fields:

                - `type` (str): Type of the learning rate schedule. Expected
                  values include ``'epoch'`` and ``'iter'``, indicating
                  updating learning rates every epoch or iteration.
                - `policy` (str): The learning rate policy to use. Currently
                  supported policies include ``step``, ``cosine``, ``exp``,
                  ``poly``, and ``inv``.
                - Other configs for the learning rate policy, e.g.
                  ``target_lr=0``. Please refer to :obj:`LrUpdaterHook` for
                  full configs.

            - `warmup` (dict, optional): The warm-up policy config containing
              the following fields:

                - `type` (str): Type of the warm-up schedule. Expected values
                  include ``'epoch'`` and ``'iter'``, indicating warming up
                  for ``steps`` epochs or iterations.
                - `policy` (str): The warm-up policy to use. Currently
                  supported policies include ``linear``, ``exp`` and
                  ``constant``.
                - `steps` (int): Number of iterations to warm up.
                - `ratio` (float): The ratio of learning rate to start with.
                  Expected values are in the range of ``0 ~ 1``.

            - `validation` (dict, optional): The validation config containing
              the following fields:

                - `interval` (int, optional): The interval of performing
                  validation. ``0`` means not performing validation.
                  Default: ``0``.
                - `offset` (int, optional): The number of epochs to skip
                  before counting the interval. Default: ``0``.

            Default: ``None``.
        hooks (list[:obj:`Hook` | dict | str] | None, optional): The list of
            extra hooks to be registered. Each hook can be represented as a
            :obj:`Hook`, a dict or a str. Default: ``None``.
        buffer_size (int, optional): Maximum size of the buffer. Default:
            ``100000``.
        logger (:obj:`logging.Logger` | str | None, optional): The logger or
            name of the logger to use. Default: ``None``.
        work_dir (str | None, optional): Path to the working directory. If
            not specified, the default working directory will be used.
            Default: ``None``.
        seed (int | None, optional): The random seed to use in data loaders.
            Default: ``None``.
        meta (any | None, optional): A dictionary-like object containing meta
            data of this engine. Default: ``None``.
        amp (dict | str | bool | None, optional): Whether to use automatic
            mixed precision training. Default: ``None``.
        debug (bool, optional): Whether to activate debug mode. Default:
            ``False``.

    Example:
        >>> # Build model
        >>> model = build_model()
        ...
        >>> # Build data loaders
        >>> train_loader = build_dataloader(split='train')
        >>> val_loader = build_dataloader(split='val')
        >>> data_loaders = dict(train=train_loader, val=val_loader)
        ...
        >>> # Configure stages:
        >>> # [Stage 1] Train the model for 5 epochs using Adam optimizer
        >>> # with a fixed learning rate (1e-3) and a linear warm-up policy.
        >>> # [Stage 2] Train the model for another 3 epochs using SGD with
        >>> # momentum optimizer and an iter-based cosine learning rate
        >>> # schedule. Perform validation after every training epoch.
        >>> stages = [
        ...     dict(
        ...         epochs=5,
        ...         optimizer=dict(type='Adam', lr=1e-3),
        ...         warmup=dict(type='iter', policy='linear', steps=500)),
        ...     dict(
        ...         epochs=3,
        ...         optimizer=dict(type='SGD', lr=1e-3, momentum=0.9),
        ...         lr_schedule=dict(type='iter', policy='cosine'),
        ...         validation=dict(interval=1))
        ... ]
        ...
        >>> # Initialize and launch engine
        >>> engine = Engine(model, data_loaders, stages=stages)
        >>> engine.launch()
    """

    def __init__(self,
                 model,
                 data_loaders,
                 stages=None,
                 hooks=None,
                 buffer_size=100000,
                 logger=None,
                 work_dir=None,
                 seed=None,
                 meta=None,
                 amp=None,
                 debug=False,
                 **kwargs):
        self.model = build_model(model, **kwargs)

        if 'train' not in data_loaders:
            data_loaders = dict(train=data_loaders)

        for a, b in (('val', 'test'), ('test', 'val')):
            if a not in data_loaders:
                loader = data_loaders[b if b in data_loaders else 'train']
                if isinstance(loader, dict):
                    data_loaders[a] = loader.copy()
                else:
                    data_loaders[a] = loader

        self.data_loaders = {
            k: build_dataloader(v, seed=seed)
            for k, v in data_loaders.items()
        }

        if isinstance(stages, dict):
            self.stages = [stages]
        else:
            self.stages = stages or _DEFAULT_STAGES

        self.register_hook(_DEFAULT_HOOKS)
        if is_distributed():
            self.register_hook('SamplerSeedHook', before='OptimizerHook')
        if hooks is not None:
            self.register_hook(hooks)

        time_str = nncore.get_timestamp()
        self.work_dir = work_dir or nncore.join('work_dirs', time_str)
        log_file = nncore.join(self.work_dir, time_str + '.log')
        self.logger = nncore.get_logger(logger, log_file=log_file)

        self.buffer = Buffer(max_size=buffer_size, logger=self.logger)
        self.reset_states()

        if isinstance(amp, dict):
            amp.setdefault('enabled', True)
            self.amp_cfg = amp
        elif isinstance(amp, str):
            if amp in ('fp16', 'float16'):
                dtype = torch.float16
            elif amp in ('bf16', 'bfloat16'):
                dtype = torch.bfloat16
            else:
                raise TypeError(
                    "Amp training only supports 'float16' or 'bfloat16' "
                    "data types, but got '{}'".format(amp))
            self.amp_cfg = dict(enabled=True, dtype=dtype)
        else:
            self.amp_cfg = dict(enabled=bool(amp))

        self.meta = meta
        self.debug = debug

    @property
    def cur_stage(self):
        return self.stages[self._stage]

    @property
    def epoch_in_stage(self):
        cumsum = 0
        for stage in self.stages:
            if self._epoch + 1 <= cumsum + stage['epochs']:
                return self._epoch - cumsum
            cumsum += stage['epochs']
        return self.stages[-1]['epochs']

    @property
    def iter_in_stage(self):
        cumsum = 0
        for i in range(self._stage):
            cumsum += len(
                self.data_loaders['train']) * self.stages[i]['epochs']
        return self._iter - cumsum

    @property
    def iter_in_epoch(self):
        return self._iter - len(self.data_loaders['train']) * self._epoch

    def _call_hook(self, name):
        for hook in self.hooks.values():
            getattr(hook, name)(self)

    def get_amp_type(self):
        if self.amp_cfg['enabled']:
            dtype = self.amp_cfg.get('dtype', torch.float16)
            return 'fp16' if dtype is torch.float16 else 'bf16'

    def reset_states(self):
        self.buffer.clear()
        self._max_stages = 0 if self.stages is None else len(self.stages)
        self._max_epochs = 0 if self.stages is None else sum(
            stage['epochs'] for stage in self.stages)
        self._max_iters = (len(self.data_loaders['train']) if 'train' in
                           self.data_loaders else 0) * self._max_epochs
        self._start_iter = self._stage = self._epoch = self._iter = 0
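    # Illustrative sketch (not part of the original source): ``amp`` accepts
    # a bool, a string ('fp16'/'float16' or 'bf16'/'bfloat16'), or a dict
    # whose fields are forwarded to ``torch.autocast`` in ``train_iter``:
    #
    #   >>> engine = Engine(model, data_loaders, amp='bf16')
    #   >>> engine.get_amp_type()
    #   'bf16'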
    def register_hook(self, hook, before=None, overwrite=True, **kwargs):
        """
        Register a hook or a list of hooks into the engine.

        Args:
            hook (list | :obj:`Hook` | dict | str): The hook or list of
                hooks to be registered. Each hook can be represented as a
                :obj:`Hook`, a dict or a str.
            before (str, optional): Name of the hook to be inserted before.
                If not specified, the new hook will be added to the end of
                the hook list. Default: ``None``.
            overwrite (bool, optional): Whether to overwrite the old hook
                with the same name if it exists. Default: ``True``.
        """
        if isinstance(hook, (list, tuple)):
            for h in hook:
                self.register_hook(
                    h, before=before, overwrite=overwrite, **kwargs)
            return
        elif isinstance(hook, (dict, str)):
            hook = build_hook(hook, **kwargs)
        elif not isinstance(hook, Hook):
            raise TypeError(
                "hook must be a Hook, a dict or a str, but got '{}'".format(
                    type(hook)))

        if not hasattr(self, 'hooks'):
            self.hooks = OrderedDict()

        if hook.name in self.hooks:
            if overwrite:
                keys = list(self.hooks.keys())
                if before is None and keys[-1] != hook.name:
                    before = keys[keys.index(hook.name) + 1]
                self.hooks.pop(hook.name)
            else:
                raise KeyError("hook '{}' exists".format(hook.name))

        self.hooks[hook.name] = hook

        if before is not None:
            if before not in self.hooks:
                raise ValueError("hook '{}' not found".format(before))
            keys = list(self.hooks.keys())
            for key in keys[keys.index(before):-1]:
                self.hooks.move_to_end(key)
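    # Illustrative sketch (not part of the original source): hooks can be
    # registered by name or by config dict, optionally before an existing
    # hook, mirroring the ``SamplerSeedHook`` registration in ``__init__``:
    #
    #   >>> engine.register_hook('SamplerSeedHook', before='OptimizerHook')
    #   >>> engine.register_hook(dict(type='SamplerSeedHook'))  # same effect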
    def unregister_hook(self, hook):
        """
        Unregister a hook or a list of hooks from the engine.

        Args:
            hook (list | :obj:`Hook` | str): The hook or list of hooks to be
                unregistered. Each hook can be represented as a :obj:`Hook`
                or a str.
        """
        if isinstance(hook, (list, tuple)):
            for h in hook:
                self.unregister_hook(h)
            return

        if isinstance(hook, Hook):
            hook = hook.name

        self.hooks.pop(hook)
    def load_checkpoint(self, checkpoint, **kwargs):
        """
        Load checkpoint from a file or a URL.

        Args:
            checkpoint (dict | str): A dict, a filename, a URL or a
                ``torchvision://<model_name>`` str indicating the checkpoint.
        """
        load_checkpoint(
            self.model,
            checkpoint,
            map_location=next(self.model.parameters()).device,
            logger=self.logger,
            **kwargs)

        if isinstance(checkpoint, str):
            self.logger.info('Loaded checkpoint from {}'.format(checkpoint))
        else:
            self.logger.info('Loaded checkpoint')
    def resume(self, checkpoint, **kwargs):
        """
        Resume from a checkpoint file.

        Args:
            checkpoint (dict | str): A dict, a filename or a URL indicating
                the checkpoint.
        """
        if isinstance(checkpoint, str):
            checkpoint = get_checkpoint(
                checkpoint,
                map_location=next(self.model.parameters()).device)

        if self.stages != checkpoint['meta']['stages']:
            self.logger.warn(
                'Stages in the engine and the checkpoint do not match:'
                '\n\nCurrent stages: {}\n\nCheckpoint stages: {}'.format([
                    c.to_dict() if isinstance(c, CfgNode) else c
                    for c in self.stages
                ], checkpoint['meta']['stages']))

        load_checkpoint(self.model, checkpoint, logger=self.logger, **kwargs)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = self._start_iter = checkpoint['meta']['iter']

        cumsum, count = 0, 0
        for stage in self.stages:
            if self._epoch + 1 <= cumsum + stage['epochs']:
                break
            cumsum += stage['epochs']
            count += 1
        self._stage = count

        if 'optimizer' in checkpoint:
            self.optimizer = build_optimizer(
                self.cur_stage['optimizer'],
                params=[
                    p for p in self.model.parameters() if p.requires_grad
                ])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            raise KeyError('optimizer not found in the checkpoint')

        self.logger.info('Resumed stage {}, epoch {}, iter {}'.format(
            self._stage + 1, self._epoch, self._iter))
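    # Illustrative sketch (not part of the original source; the checkpoint
    # path is hypothetical): resuming restores the epoch/iter counters, the
    # stage index, and the optimizer state from the checkpoint's meta:
    #
    #   >>> engine = Engine(model, data_loaders, stages=stages)
    #   >>> engine.resume('work_dirs/20230101_000000/epoch_3.pth')
    #   >>> engine.launch()  # continues from the stage/epoch/iter in meta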
    def train_iter(self, data):
        self._call_hook('before_train_iter')

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        with torch.autocast(device, **self.amp_cfg):
            output = self.model(data, mode=self._mode, **self._kwargs)

        self.losses = {k: v for k, v in output.items() if 'loss' in k}
        if 'loss' not in output:
            self.losses['loss'] = output['loss'] = sum(
                v for v in self.losses.values())

        for key, value in output.items():
            self.buffer.update(
                key,
                value.detach().cpu() if torch.is_tensor(value) else value)

        self._call_hook('after_train_iter')
        self._iter += 1

    def val_iter(self, data):
        self._call_hook('before_val_iter')

        with torch.no_grad():
            output = self.model(data, mode=self._mode, **self._kwargs)

        if any('loss' in key for key in output) and 'loss' not in output:
            output['loss'] = sum(v for k, v in output.items() if 'loss' in k)

        for key, value in output.items():
            self.buffer.update(
                key,
                value.detach().cpu() if torch.is_tensor(value) else value)

        self._call_hook('after_val_iter')

    def test_iter(self, data):
        with torch.no_grad():
            output = self.model(data, mode=self._mode, **self._kwargs)

        for key, value in output.items():
            self.buffer.update(
                key,
                value.detach().cpu() if torch.is_tensor(value) else value)

    def train_epoch(self):
        self._mode = 'train'
        self.model.train()
        self.data_loader = self.data_loaders[self._mode]

        if callable(getattr(self.data_loader.dataset, 'set_state', None)):
            self.data_loader.dataset.set_state(self._mode)

        self._call_hook('before_train_epoch')

        for data in self.data_loader:
            self.train_iter(data)

        self._call_hook('after_train_epoch')
        self._epoch += 1

    def val_epoch(self):
        self.logger.info('Validating...')

        self._mode = 'val'
        self.model.eval()
        self.buffer.pop('_out', None)
        self.data_loader = self.data_loaders[self._mode]

        if callable(getattr(self.data_loader.dataset, 'set_state', None)):
            self.data_loader.dataset.set_state(self._mode)

        self._call_hook('before_val_epoch')

        for data in nncore.ProgressBar(self.data_loader):
            self.val_iter(data)

        self._call_hook('after_val_epoch')

    def test_epoch(self):
        self.logger.info('Evaluating...')

        self._mode = 'test'
        self.model.eval()
        self.buffer.pop('_out', None)
        self.data_loader = self.data_loaders[self._mode]

        if callable(getattr(self.data_loader.dataset, 'set_state', None)):
            self.data_loader.dataset.set_state(self._mode)

        for data in nncore.ProgressBar(self.data_loader):
            self.test_iter(data)

    def run_stage(self):
        if isinstance(self.cur_stage['optimizer'], dict):
            optim_cfg = self.cur_stage['optimizer'].copy()
            optim_type = optim_cfg.pop('type')
            optim_args = ['{}: {}'.format(k, v) for k, v in optim_cfg.items()]
            optim_str = '{}({})'.format(optim_type, ', '.join(optim_args))
        else:
            optim_str = '{}()'.format(
                self.cur_stage['optimizer'].__class__.__name__)

        self.logger.info('Stage: {}, epochs: {}, optimizer: {}'.format(
            self._stage + 1, self.cur_stage['epochs'], optim_str))

        if self.epoch_in_stage == 0:
            self.optimizer = build_optimizer(
                self.cur_stage['optimizer'],
                params=[
                    p for p in self.model.parameters() if p.requires_grad
                ])

        self._call_hook('before_stage')

        for _ in range(self.cur_stage['epochs'] - self.epoch_in_stage):
            self.train_epoch()
            cfg = self.cur_stage.get('validation')
            if (cfg is not None and 'val' in self.data_loaders
                    and cfg.get('interval', 0) > 0
                    and self.epoch_in_stage > cfg.get('offset', 0)
                    and self.epoch_in_stage % cfg.get('interval', 0) == 0):
                self.val_epoch()

        self._call_hook('after_stage')
        self._stage += 1
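    # Illustrative sketch (not part of the original source) of the model
    # contract assumed by the iteration methods above: ``forward`` receives
    # the batch and a ``mode`` keyword, and any returned key containing
    # 'loss' is summed into a total 'loss' when one is not given explicitly:
    #
    #   >>> class MyModel(nn.Module):  # hypothetical model
    #   ...     def forward(self, data, mode='train', **kwargs):
    #   ...         pred = self.layers(data['x'])
    #   ...         out = dict(cls_loss=..., _avg_factor=len(data['x']))
    #   ...         if mode != 'train':
    #   ...             out['_out'] = pred
    #   ...         return out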
    def evaluate(self):
        """
        Perform evaluation. This method is expected to be called after
        validation or testing.
        """
        blob = self.buffer.pop('_out')
        blob = gather(blob)

        if is_main_process():
            blob = nncore.interleave(blob)[:len(self.data_loader.dataset)]

            cfg = self.cur_stage.get('validation')
            if cfg is not None:
                cfg = cfg.copy()
                cfg.pop('interval', None)
                cfg.pop('offset', None)
            else:
                cfg = dict()

            output = self.data_loader.dataset.evaluate(
                blob, logger=self.logger, **cfg)
        else:
            output = dict()

        sync()
        return output
    def launch(self, eval=False, **kwargs):
        """
        Launch the engine.

        Args:
            eval (bool, optional): Whether to run evaluation only.
                Default: ``False``.
        """
        self._kwargs = kwargs

        if eval:
            self.test_epoch()
            output = self.evaluate()
            self.logger.info('Evaluation results: ' + ', '.join(
                ['{}: {}'.format(k, v) for k, v in output.items()]))
            return output

        self.logger.info('Distributed: {}, AMP: {}, Debug: {}'.format(
            is_distributed(), self.get_amp_type(), self.debug))
        self.logger.info('Launch engine, host: {}, work_dir: {}'.format(
            nncore.get_host_info(), self.work_dir))

        self._call_hook('before_launch')

        while self._stage < self._max_stages:
            self.run_stage()

        self._call_hook('after_launch')
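# Illustrative sketch (not part of the original source; the checkpoint file
# is hypothetical): an evaluation-only run calls ``test_epoch`` followed by
# ``evaluate``, which relies on the dataset exposing an ``evaluate`` method:
#
#   >>> engine = Engine(model, data_loaders)
#   >>> engine.load_checkpoint('checkpoint.pth')
#   >>> results = engine.launch(eval=True)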