# Source code for nncore.engine.hooks.base

# Copyright (c) Ye Liu. Licensed under the MIT License.

from abc import ABCMeta
from types import MethodType

import nncore

# Names of every hook method recognised by ``Hook``. Two-token names
# (``<when>_<unit>``) are generic; three-token names (``<when>_<mode>_<unit>``)
# are mode-specific variants. The order of this list is significant: the
# generic names come first.
HOOK_NAMES = [
    # generic hooks
    'before_launch',
    'after_launch',
    'before_stage',
    'after_stage',
    'before_epoch',
    'after_epoch',
    'before_iter',
    'after_iter',
    # mode-specific hooks
    'before_train_epoch',
    'after_train_epoch',
    'before_val_epoch',
    'after_val_epoch',
    'before_train_iter',
    'after_train_iter',
    'before_val_iter',
    'after_val_iter',
]


@nncore.bind_getter('name')
class Hook(metaclass=ABCMeta):
    """
    Base class for hooks that can be registered into :obj:`Engine`.

    Each hook can implement several methods. In hook methods, users should
    provide an argument ``engine`` to access more properties about the
    context. All hooks will be called one by one according to the order in
    :obj:`engine.hooks`.

    Every name in ``HOOK_NAMES`` that the (sub)class does not define is filled
    in on the instance at construction time. Three-token, mode-specific hooks
    (e.g. ``before_train_epoch``) forward to their two-token counterpart
    (``before_epoch``); two-token hooks default to no-ops.

    Args:
        name (str, optional): Name of the hook. Any falsy value (including
            the default ``None``) falls back to the class name.
    """

    # NOTE(review): ``nncore.bind_getter('name')`` presumably exposes
    # ``self._name`` as a read-only ``name`` attribute; ``__eq__`` relies on
    # ``.name`` being available on the other operand.

    def __init__(self, name=None):
        self._name = name or self.__class__.__name__

        for hook_name in HOOK_NAMES:
            # Methods provided by the class itself take precedence.
            if hasattr(self, hook_name):
                continue

            token = hook_name.split('_')

            if len(token) == 3:
                # '<when>_<mode>_<unit>' forwards to '<when>_<unit>'. The
                # target name is computed now and captured by a fresh closure
                # from the factory. Closing over the loop variable ``token``
                # directly would late-bind: every forwarding hook would call
                # the target derived from the last name processed here.
                _default_hook = self._make_forward_hook('{}_{}'.format(
                    token[0], token[2]))
            else:

                def _default_hook(self, engine):
                    pass

            setattr(self, hook_name, MethodType(_default_hook, self))

    @staticmethod
    def _make_forward_hook(target):
        """Build an unbound hook function that calls ``self.<target>``."""

        def _forward_hook(self, engine):
            # Looked up at call time, so it follows whatever ``target`` is
            # bound to on the instance when the hook fires.
            getattr(self, target)(engine)

        return _forward_hook

    def __eq__(self, hook):
        # Hooks are equal when their names match. Objects without a ``name``
        # attribute are not comparable: defer to Python's default handling
        # (identity) instead of raising AttributeError.
        try:
            other_name = hook.name
        except AttributeError:
            return NotImplemented
        return self._name == other_name

    def __hash__(self):
        # Keep hashing consistent with ``__eq__``; defining ``__eq__`` alone
        # would make instances unhashable.
        return hash(self._name)

    def __repr__(self):
        return '{}()'.format(self._name)

    @staticmethod
    def _every_n(count, n):
        # ``count`` is a zero-based counter, so the check is on the
        # (count + 1)-th step. A non-positive ``n`` disables the check.
        return n > 0 and (count + 1) % n == 0

    def every_n_stages(self, engine, n):
        """Whether the current stage is a multiple of ``n`` (1-based)."""
        return self._every_n(engine.stage, n)

    def every_n_epochs(self, engine, n):
        """Whether the current epoch is a multiple of ``n`` (1-based)."""
        return self._every_n(engine.epoch, n)

    def every_n_iters(self, engine, n):
        """Whether the current iteration is a multiple of ``n`` (1-based)."""
        return self._every_n(engine.iter, n)

    def every_n_epochs_in_stage(self, engine, n):
        """Like :meth:`every_n_epochs`, counting within the current stage."""
        return self._every_n(engine.epoch_in_stage, n)

    def every_n_iters_in_stage(self, engine, n):
        """Like :meth:`every_n_iters`, counting within the current stage."""
        return self._every_n(engine.iter_in_stage, n)

    def every_n_iters_in_epoch(self, engine, n):
        """Like :meth:`every_n_iters`, counting within the current epoch."""
        return self._every_n(engine.iter_in_epoch, n)

    def first_epoch_in_stage(self, engine):
        """Whether this is the first epoch of the current stage."""
        return engine.epoch_in_stage == 0

    def first_iter_in_stage(self, engine):
        """Whether this is the first iteration of the current stage."""
        return engine.iter_in_stage == 0

    def first_iter_in_epoch(self, engine):
        """Whether this is the first iteration of the current epoch."""
        return engine.iter_in_epoch == 0

    def last_epoch_in_stage(self, engine):
        """Whether this is the last epoch of the current stage."""
        return engine.epoch_in_stage + 1 == engine.cur_stage['epochs']

    def last_iter_in_stage(self, engine):
        """
        Whether this is the last iteration of the current stage.

        The stage length is taken as ``len(train loader) * epochs``.
        """
        return engine.iter_in_stage + 1 == len(
            engine.data_loaders['train']) * engine.cur_stage['epochs']

    def last_iter_in_epoch(self, engine):
        """Whether this is the last iteration over the train data loader."""
        return engine.iter_in_epoch + 1 == len(engine.data_loaders['train'])

    def last_stage(self, engine):
        """Whether the current stage is the final one."""
        return engine.stage + 1 == engine.max_stages

    def last_epoch(self, engine):
        """Whether the current epoch is the final one."""
        return engine.epoch + 1 == engine.max_epochs

    def last_iter(self, engine):
        """Whether the current iteration is the final one."""
        return engine.iter + 1 == engine.max_iters