
ignite.contrib.engines#

Contribution module of engines and helper tools:

ignite.contrib.engines.tbptt

Tbptt_Events

Additional tbptt events.

create_supervised_tbptt_trainer

Create a trainer for truncated backprop through time supervised models.

ignite.contrib.engines.common

add_early_stopping_by_val_score

Sets up an early stopping handler based on the score (named by metric_name) provided by evaluator.

gen_save_best_models_by_val_score

Adds a handler to evaluator that saves the n_saved best models based on the metric (named by metric_name) provided by evaluator (i.e. evaluator.state.metrics[metric_name]).

save_best_model_by_val_score

Adds a handler to evaluator that saves to disk the n_saved best models based on the metric (named by metric_name) provided by evaluator (i.e. evaluator.state.metrics[metric_name]).

setup_any_logging

Deprecated function.

setup_common_distrib_training_handlers

Helper method to set up the trainer with common handlers (also supports distributed configuration).

setup_common_training_handlers

Helper method to set up the trainer with common handlers (also supports distributed configuration).

Truncated Backpropagation Through Time#

class ignite.contrib.engines.tbptt.Tbptt_Events(value)[source]#

Additional tbptt events.

Additional events for the trainer dedicated to truncated backpropagation through time.

ignite.contrib.engines.tbptt.create_supervised_tbptt_trainer(model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=<function _prepare_batch>)[source]#

Create a trainer for truncated backprop through time supervised models.

Training a recurrent model on long sequences is computationally intensive because the whole sequence must be processed before a gradient is available. However, when the training loss is computed over many outputs (X to many), a gradient can be computed over a subsequence instead. This is known as truncated backpropagation through time. This supervised trainer applies a gradient optimization step every tbtt_step time steps of the sequence, while backpropagating through the same tbtt_step time steps.
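The chunking described above can be illustrated in plain Python. This is only a sketch under stated assumptions, not the library's implementation: split_time_chunks is a hypothetical helper, and a Python list stands in for the time dimension of a tensor. It shows how a sequence is cut into chunks of tbtt_step time steps, with a shorter final chunk when the length is not a multiple of tbtt_step.

```python
def split_time_chunks(sequence, tbtt_step):
    """Split a sequence into consecutive chunks of at most tbtt_step items.

    Hypothetical helper for illustration only; the real trainer works on
    tensors along the time dimension `dim`.
    """
    if tbtt_step < 1:
        raise ValueError("tbtt_step must be a positive integer")
    return [
        sequence[start:start + tbtt_step]
        for start in range(0, len(sequence), tbtt_step)
    ]


# A sequence of 7 time steps with tbtt_step=3 yields chunks of 3, 3 and 1.
chunks = split_time_chunks(list(range(7)), 3)

# One optimization step would be applied per chunk.
num_optimizer_steps = len(chunks)
```

In this sketch the number of chunks equals the number of optimization steps taken for one sequence, which matches the description that a step is applied every tbtt_step time steps.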

Parameters:
  • model (Module) – the model to train.

  • optimizer (Optimizer) – the optimizer to use.

  • loss_fn (Module) – the loss function to use.

  • tbtt_step (int) – the length of time chunks (last one may be smaller).

  • dim (int) – axis representing the time dimension.

  • device (str | None) – device type specification (default: None). Applies to batches.

  • non_blocking (bool) – if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect.

  • prepare_batch (Callable) – function that receives batch, device, non_blocking and outputs tuple of tensors (batch_x, batch_y).

Returns:

a trainer engine with supervised update function.

Return type:

Engine

Warning

The internal use of device has changed. device will now only be used to move the input data to the correct device. The model should be moved by the user before creating an optimizer.

For more information see:

Helper methods to setup trainer/evaluator#

Note

Logger setup helpers (setup_tb_logging, setup_mlflow_logging, etc.) are now implemented in ignite.handlers.logger_utils and re-exported from ignite.contrib.engines.common for backward compatibility.

ignite.contrib.engines.common.add_early_stopping_by_val_score(patience, evaluator, trainer, metric_name, score_sign=1.0)[source]#

Sets up an early stopping handler based on the score (named by metric_name) provided by evaluator. The metric value should increase for training to continue; otherwise training stops early.

Parameters:
  • patience (int) – number of events to wait if no improvement and then stop the training.

  • evaluator (Engine) – evaluation engine used to provide the score

  • trainer (Engine) – trainer engine to stop the run if no improvement.

  • metric_name (str) – metric name to use for score evaluation. This metric should be present in evaluator.state.metrics.

  • score_sign (float) – sign of the score: 1.0 or -1.0. For error-like metrics, where smaller is better, a negative score sign should be used (objects with larger score are retained). Default: 1.0.

Returns:

An EarlyStopping handler.

Return type:

EarlyStopping
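The interaction between patience and score_sign can be sketched in plain Python. This is an illustration of the behavior described above, not the EarlyStopping handler's code: early_stop_index is a hypothetical function, and the rule that only a strictly larger score counts as an improvement is an assumption.

```python
def early_stop_index(metric_values, patience, score_sign=1.0):
    """Return the index of the evaluation at which training would stop.

    Hypothetical illustration: the score is score_sign * metric, and the
    counter of evaluations without improvement resets on every new best.
    Returns None if training is never stopped.
    """
    best_score = None
    stale_evaluations = 0
    for index, value in enumerate(metric_values):
        score = score_sign * value
        if best_score is None or score > best_score:
            best_score = score
            stale_evaluations = 0
        else:
            stale_evaluations += 1
            if stale_evaluations >= patience:
                return index
    return None


# Validation loss is an error-like metric, so a negative sign is used.
losses = [1.0, 0.8, 0.9, 0.85, 0.7]
stop_at = early_stop_index(losses, patience=2, score_sign=-1.0)

# Accuracy that keeps improving never triggers a stop.
never_stops = early_stop_index([0.5, 0.6, 0.7], patience=2)
```

With the losses above, the best value is reached at index 1 and the two following evaluations do not improve on it, so the sketch stops at index 3 and the later value 0.7 is never seen.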

ignite.contrib.engines.common.gen_save_best_models_by_val_score(save_handler, evaluator, models, metric_name, n_saved=3, trainer=None, tag='val', score_sign=1.0, **kwargs)[source]#

Adds a handler to evaluator that saves the n_saved best models based on the metric (named by metric_name) provided by evaluator (i.e. evaluator.state.metrics[metric_name]). Models with the highest metric value are retained. The logic of how to store objects is delegated to save_handler.

Parameters:
  • save_handler (Callable | BaseSaveHandler) – Method or callable class used to save the engine and other provided objects. The function receives two objects: the checkpoint as a dictionary and the filename. If save_handler is a callable class, it can inherit from BaseSaveHandler and optionally implement a remove method to keep a fixed number of saved checkpoints. If the user needs to save the engine’s checkpoint to disk, save_handler can be defined with DiskSaver.

  • evaluator (Engine) – evaluation engine used to provide the score

  • models (Module | dict[str, torch.nn.modules.module.Module]) – model or dictionary of objects to save. Objects must implement state_dict and load_state_dict methods.

  • metric_name (str) – metric name to use for score evaluation. This metric should be present in evaluator.state.metrics.

  • n_saved (int) – number of best models to store

  • trainer (Engine | None) – trainer engine to fetch the epoch when saving the best model.

  • tag (str) – score name prefix: {tag}_{metric_name}. By default, tag is “val”.

  • score_sign (float) – sign of the score: 1.0 or -1.0. For error-like metrics, where smaller is better, a negative score sign should be used (objects with larger score are retained). Default: 1.0.

  • kwargs (Any) – optional keyword args to be passed to construct Checkpoint.

Returns:

A Checkpoint handler.

Return type:

Checkpoint
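How n_saved, score_sign and the {tag}_{metric_name} score name fit together can be sketched without the library. The function best_epochs and the sample values are hypothetical; the real Checkpoint handler decides what to keep incrementally as evaluations arrive, whereas this sketch ranks all epochs at once for clarity.

```python
def best_epochs(metric_by_epoch, n_saved=3, score_sign=1.0):
    """Return the epochs whose models would be retained, in ascending order.

    Hypothetical illustration: objects with the largest
    score_sign * metric value are kept.
    """
    ranked = sorted(
        metric_by_epoch.items(),
        key=lambda item: score_sign * item[1],
        reverse=True,
    )
    return sorted(epoch for epoch, _ in ranked[:n_saved])


# Score name prefix as described for the tag parameter.
tag, metric_name = "val", "loss"
score_name = f"{tag}_{metric_name}"

# Validation loss per epoch: smaller is better, so score_sign is -1.0.
val_loss = {1: 0.9, 2: 0.5, 3: 0.7, 4: 0.4}
kept = best_epochs(val_loss, n_saved=2, score_sign=-1.0)
```

Here the two lowest losses belong to epochs 2 and 4, so those are the models that would be retained in this sketch.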

ignite.contrib.engines.common.save_best_model_by_val_score(output_path, evaluator, model, metric_name, n_saved=3, trainer=None, tag='val', score_sign=1.0, **kwargs)[source]#

Adds a handler to evaluator that saves to disk the n_saved best models based on the metric (named by metric_name) provided by evaluator (i.e. evaluator.state.metrics[metric_name]). Models with the highest metric value are retained.

Parameters:
  • output_path (str) – output path to indicate where to save best models

  • evaluator (Engine) – evaluation engine used to provide the score

  • model (Module) – model to store

  • metric_name (str) – metric name to use for score evaluation. This metric should be present in evaluator.state.metrics.

  • n_saved (int) – number of best models to store

  • trainer (Engine | None) – trainer engine to fetch the epoch when saving the best model.

  • tag (str) – score name prefix: {tag}_{metric_name}. By default, tag is “val”.

  • score_sign (float) – sign of the score: 1.0 or -1.0. For error-like metrics, where smaller is better, a negative score sign should be used (objects with larger score are retained). Default: 1.0.

  • kwargs (Any) – optional keyword args to be passed to construct Checkpoint.

Returns:

A Checkpoint handler.

Return type:

Checkpoint

ignite.contrib.engines.common.setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters)[source]#

Deprecated function.

Deprecated since version 0.4.0:

  • Please use instead: setup_tb_logging, setup_visdom_logging or setup_mlflow_logging etc.

Return type:

None

ignite.contrib.engines.common.setup_common_distrib_training_handlers(trainer, train_sampler=None, to_save=None, save_every_iters=1000, output_path=None, lr_scheduler=None, with_gpu_stats=False, output_names=None, with_pbars=True, with_pbar_on_iters=True, log_every_iters=100, stop_on_nan=True, clear_cuda_cache=True, save_handler=None, **kwargs)#

Helper method to set up the trainer with common handlers (also supports distributed configuration).

Parameters:
  • trainer (Engine) – trainer engine. Output of trainer’s update_function should be a dictionary or sequence or a single tensor.

  • train_sampler (DistributedSampler | None) – Optional distributed sampler used to call set_epoch method on epoch started event.

  • to_save (Mapping | None) – dictionary with objects to save in the checkpoint. This argument is passed to Checkpoint instance.

  • save_every_iters (int) – saving interval. By default, to_save objects are stored every 1000 iterations.

  • output_path (str | None) – output path to indicate where to_save objects are stored. Argument is mutually exclusive with save_handler.

  • lr_scheduler (ParamScheduler | LRScheduler | None) – learning rate scheduler as native torch LRScheduler or ignite’s parameter scheduler.

  • with_gpu_stats (bool) – if True, GpuInfo is attached to the trainer. This requires pynvml<12 package to be installed.

  • output_names (Iterable[str] | None) – list of names associated with update_function output dictionary.

  • with_pbars (bool) – if True, two progress bars on epochs and optionally on iterations are attached. Default, True.

  • with_pbar_on_iters (bool) – if True, a progress bar on iterations is attached to the trainer. Default, True.

  • log_every_iters (int) – logging interval for GpuInfo and for epoch-wise progress bar. Default, 100.

  • stop_on_nan (bool) – if True, TerminateOnNan handler is added to the trainer. Default, True.

  • clear_cuda_cache (bool) – if True, torch.cuda.empty_cache() is called at the end of every epoch. Default, True.

  • save_handler (Callable | BaseSaveHandler | None) – Method or callable class to use to store to_save. See Checkpoint for more details. Argument is mutually exclusive with output_path.

  • kwargs (Any) – optional keyword args to be passed to construct Checkpoint.

Return type:

None

ignite.contrib.engines.common.setup_common_training_handlers(trainer, train_sampler=None, to_save=None, save_every_iters=1000, output_path=None, lr_scheduler=None, with_gpu_stats=False, output_names=None, with_pbars=True, with_pbar_on_iters=True, log_every_iters=100, stop_on_nan=True, clear_cuda_cache=True, save_handler=None, **kwargs)[source]#

Helper method to set up the trainer with common handlers (also supports distributed configuration).

Parameters:
  • trainer (Engine) – trainer engine. Output of trainer’s update_function should be a dictionary or sequence or a single tensor.

  • train_sampler (DistributedSampler | None) – Optional distributed sampler used to call set_epoch method on epoch started event.

  • to_save (Mapping | None) – dictionary with objects to save in the checkpoint. This argument is passed to Checkpoint instance.

  • save_every_iters (int) – saving interval. By default, to_save objects are stored every 1000 iterations.

  • output_path (str | None) – output path to indicate where to_save objects are stored. Argument is mutually exclusive with save_handler.

  • lr_scheduler (ParamScheduler | LRScheduler | None) – learning rate scheduler as native torch LRScheduler or ignite’s parameter scheduler.

  • with_gpu_stats (bool) – if True, GpuInfo is attached to the trainer. This requires pynvml<12 package to be installed.

  • output_names (Iterable[str] | None) – list of names associated with update_function output dictionary.

  • with_pbars (bool) – if True, two progress bars on epochs and optionally on iterations are attached. Default, True.

  • with_pbar_on_iters (bool) – if True, a progress bar on iterations is attached to the trainer. Default, True.

  • log_every_iters (int) – logging interval for GpuInfo and for epoch-wise progress bar. Default, 100.

  • stop_on_nan (bool) – if True, TerminateOnNan handler is added to the trainer. Default, True.

  • clear_cuda_cache (bool) – if True, torch.cuda.empty_cache() is called at the end of every epoch. Default, True.

  • save_handler (Callable | BaseSaveHandler | None) – Method or callable class to use to store to_save. See Checkpoint for more details. Argument is mutually exclusive with output_path.

  • kwargs (Any) – optional keyword args to be passed to construct Checkpoint.

Return type:

None
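Two of the parameters above lend themselves to a small plain-Python sketch: the iteration intervals (save_every_iters, log_every_iters) and the mutual exclusivity of output_path and save_handler. Both functions are hypothetical illustrations, not library code. The sketch assumes iterations are counted from 1 and that a handler fires when the iteration number is a multiple of the interval; raising ValueError for conflicting arguments is also an assumption.

```python
def firing_iterations(total_iters, every):
    """Iterations (counted from 1) at which a periodic handler would fire."""
    return [i for i in range(1, total_iters + 1) if i % every == 0]


def check_save_target(output_path=None, save_handler=None):
    """Return which argument selects the save target, or None if neither is set.

    Hypothetical check: the two arguments are mutually exclusive.
    """
    if output_path is not None and save_handler is not None:
        raise ValueError("output_path and save_handler are mutually exclusive")
    if output_path is not None:
        return "output_path"
    if save_handler is not None:
        return "save_handler"
    return None


# With the default save_every_iters=1000, a 2500-iteration run saves twice.
save_points = firing_iterations(2500, every=1000)

# With the default log_every_iters=100, a 350-iteration run logs three times.
log_points = firing_iterations(350, every=100)

target = check_save_target(output_path="checkpoints")
```

Passing both output_path and save_handler to check_save_target raises in this sketch, which mirrors the "mutually exclusive" wording in the parameter descriptions.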
