torchbearer

Notes

  • The Trial Class
    • Instantiation
      • Criterions
    • Loading Data
      • Generators
      • Tensors
      • Running Without Data
    • Controlling Verbosity
  • The Metric API
    • Default Keys
    • Metric Decorators
      • Lambda Metrics
      • Metric Output - to_dict
    • Data Flow - The Metric Tree
  • The Callback API
    • Aims
    • Fluent
  • Using DistributedDataParallel with Torchbearer on CPU
    • Setup, Cleanup and Model
    • Sync Methods
    • Worker Function
    • Running
    • Running on machines with GPUs
    • Source Code
  • Using the Tensorboard Callback
    • Setup
    • Logging the Model Graph
    • Logging Batch Metrics
    • Logging Epoch Metrics
    • Source Code
  • Logging to Visdom
    • Model Setup
    • Logging Epoch and Batch Metrics
    • Visdom Client Parameters
    • Source Code

Package Reference

  • torchbearer
    • Trial
      • Trial
        • Trial.for_train_steps()
        • Trial.with_train_generator()
        • Trial.with_train_data()
        • Trial.for_val_steps()
        • Trial.with_val_generator()
        • Trial.with_val_data()
        • Trial.for_test_steps()
        • Trial.with_test_generator()
        • Trial.with_test_data()
        • Trial.for_steps()
        • Trial.with_generators()
        • Trial.with_data()
        • Trial.for_inf_train_steps()
        • Trial.for_inf_val_steps()
        • Trial.for_inf_test_steps()
        • Trial.for_inf_steps()
        • Trial.with_inf_train_loader()
        • Trial.with_loader()
        • Trial.with_closure()
        • Trial.run()
        • Trial.evaluate()
        • Trial.predict()
        • Trial.replay()
        • Trial.train()
        • Trial.eval()
        • Trial.to()
        • Trial.cuda()
        • Trial.cpu()
        • Trial.state_dict()
        • Trial.load_state_dict()
      • Batch Loaders
        • load_batch_infinite()
        • load_batch_none()
        • load_batch_predict()
        • load_batch_standard()
      • Misc
        • deep_to()
        • update_device_and_dtype()
    • State
      • State
        • State
        • StateKey
        • state_key()
      • Key List
    • Utilities
      • DatasetValidationSplitter
        • DatasetValidationSplitter.get_train_dataset()
        • DatasetValidationSplitter.get_val_dataset()
      • SubsetDataset
      • get_train_valid_sets()
      • train_valid_splitter()
      • base_closure()
  • torchbearer.callbacks
    • Base Classes
      • Callback
        • Callback.state_dict()
        • Callback.load_state_dict()
        • Callback.on_init()
        • Callback.on_start()
        • Callback.on_start_epoch()
        • Callback.on_start_training()
        • Callback.on_sample()
        • Callback.on_forward()
        • Callback.on_criterion()
        • Callback.on_backward()
        • Callback.on_step_training()
        • Callback.on_end_training()
        • Callback.on_start_validation()
        • Callback.on_sample_validation()
        • Callback.on_forward_validation()
        • Callback.on_criterion_validation()
        • Callback.on_step_validation()
        • Callback.on_end_validation()
        • Callback.on_end_epoch()
        • Callback.on_checkpoint()
        • Callback.on_end()
      • CallbackList
        • CallbackList.CALLBACK_STATES
        • CallbackList.CALLBACK_TYPES
        • CallbackList.state_dict()
        • CallbackList.load_state_dict()
        • CallbackList.copy()
        • CallbackList.append()
        • CallbackList.on_init()
        • CallbackList.on_start()
        • CallbackList.on_start_epoch()
        • CallbackList.on_start_training()
        • CallbackList.on_sample()
        • CallbackList.on_forward()
        • CallbackList.on_criterion()
        • CallbackList.on_backward()
        • CallbackList.on_step_training()
        • CallbackList.on_end_training()
        • CallbackList.on_start_validation()
        • CallbackList.on_sample_validation()
        • CallbackList.on_forward_validation()
        • CallbackList.on_criterion_validation()
        • CallbackList.on_step_validation()
        • CallbackList.on_end_validation()
        • CallbackList.on_end_epoch()
        • CallbackList.on_checkpoint()
        • CallbackList.on_end()
    • Imaging
      • Main Classes
        • CachingImagingCallback
        • FromState
        • ImagingCallback
        • MakeGrid
      • Deep Inside Convolutional Networks
        • ClassAppearanceModel
        • RANDOM
    • Model Checkpointers
      • Best
        • Best.load_state_dict()
        • Best.on_checkpoint()
        • Best.on_start()
        • Best.state_dict()
      • Interval
        • Interval.load_state_dict()
        • Interval.on_checkpoint()
        • Interval.state_dict()
      • ModelCheckpoint()
      • MostRecent
        • MostRecent.on_checkpoint()
    • Logging
      • CSVLogger
        • CSVLogger.on_end()
        • CSVLogger.on_end_epoch()
        • CSVLogger.on_start()
        • CSVLogger.on_step_training()
      • ConsolePrinter
        • ConsolePrinter.on_end_training()
        • ConsolePrinter.on_end_validation()
        • ConsolePrinter.on_step_training()
        • ConsolePrinter.on_step_validation()
      • Tqdm
        • Tqdm.on_end()
        • Tqdm.on_end_epoch()
        • Tqdm.on_end_training()
        • Tqdm.on_end_validation()
        • Tqdm.on_start()
        • Tqdm.on_start_training()
        • Tqdm.on_start_validation()
        • Tqdm.on_step_training()
        • Tqdm.on_step_validation()
    • Tensorboard, Visdom and Others
      • AbstractTensorBoard
        • AbstractTensorBoard.add_metric()
        • AbstractTensorBoard.close_writer()
        • AbstractTensorBoard.get_writer()
        • AbstractTensorBoard.on_end()
        • AbstractTensorBoard.on_start()
      • TensorBoard
        • TensorBoard.on_end()
        • TensorBoard.on_end_epoch()
        • TensorBoard.on_sample()
        • TensorBoard.on_start_epoch()
        • TensorBoard.on_step_training()
        • TensorBoard.on_step_validation()
      • TensorBoardImages
        • TensorBoardImages.on_end_epoch()
        • TensorBoardImages.on_step_validation()
      • TensorBoardProjector
        • TensorBoardProjector.on_end_epoch()
        • TensorBoardProjector.on_step_validation()
      • TensorBoardText
        • TensorBoardText.on_end()
        • TensorBoardText.on_end_epoch()
        • TensorBoardText.on_start()
        • TensorBoardText.on_start_epoch()
        • TensorBoardText.on_step_training()
        • TensorBoardText.table_formatter()
      • VisdomParams
        • VisdomParams.ENDPOINT
        • VisdomParams.ENV
        • VisdomParams.HTTP_PROXY_HOST
        • VisdomParams.HTTP_PROXY_PORT
        • VisdomParams.IPV6
        • VisdomParams.LOG_TO_FILENAME
        • VisdomParams.PORT
        • VisdomParams.RAISE_EXCEPTIONS
        • VisdomParams.SEND
        • VisdomParams.SERVER
        • VisdomParams.USE_INCOMING_SOCKET
      • close_writer()
      • get_writer()
      • LiveLossPlot
        • LiveLossPlot.on_end()
        • LiveLossPlot.on_start()
    • Early Stopping
      • EarlyStopping
        • EarlyStopping.load_state_dict()
        • EarlyStopping.on_end_epoch()
        • EarlyStopping.on_step_training()
        • EarlyStopping.state_dict()
        • EarlyStopping.step()
      • TerminateOnNaN
        • TerminateOnNaN.on_end_epoch()
        • TerminateOnNaN.on_step_training()
        • TerminateOnNaN.on_step_validation()
    • Gradient Clipping
      • GradientClipping
        • GradientClipping.on_backward()
        • GradientClipping.on_start()
      • GradientNormClipping
        • GradientNormClipping.on_backward()
        • GradientNormClipping.on_start()
    • Learning Rate Schedulers
      • CosineAnnealingLR
      • CyclicLR
      • ExponentialLR
      • LambdaLR
      • MultiStepLR
      • ReduceLROnPlateau
      • StepLR
      • TorchScheduler
        • TorchScheduler.on_end_epoch()
        • TorchScheduler.on_sample()
        • TorchScheduler.on_start()
        • TorchScheduler.on_start_training()
        • TorchScheduler.on_step_training()
    • Weight Decay
      • L1WeightDecay
      • L2WeightDecay
      • WeightDecay
        • WeightDecay.on_criterion()
        • WeightDecay.on_start()
    • Weight / Bias Initialisation
      • KaimingNormal
      • KaimingUniform
      • LsuvInit
        • LsuvInit.on_init()
      • WeightInit
        • WeightInit.on_init()
      • XavierNormal
      • XavierUniform
      • ZeroBias
    • Regularisers
      • Cutout
        • Cutout.on_sample()
      • RandomErase
        • RandomErase.on_sample()
      • CutMix
        • CutMix.on_sample()
        • CutMix.on_sample_validation()
      • Mixup
        • Mixup.RANDOM
        • Mixup.mixup_loss()
        • Mixup.on_sample()
      • BCPlus
        • BCPlus.bc_loss()
        • BCPlus.on_sample()
        • BCPlus.on_sample_validation()
      • SamplePairing
        • SamplePairing.default_policy()
        • SamplePairing.on_sample()
      • LabelSmoothingRegularisation
        • LabelSmoothingRegularisation.on_sample()
        • LabelSmoothingRegularisation.on_sample_validation()
        • LabelSmoothingRegularisation.to_one_hot()
    • Unpack State
      • unpack_state
    • Decorators
      • Main
        • on_init()
        • on_start()
        • on_start_epoch()
        • on_start_training()
        • on_sample()
        • on_forward()
        • on_criterion()
        • on_backward()
        • on_step_training()
        • on_end_training()
        • on_start_validation()
        • on_sample_validation()
        • on_forward_validation()
        • on_criterion_validation()
        • on_step_validation()
        • on_end_validation()
        • on_end_epoch()
        • on_checkpoint()
        • on_end()
      • Utility
  • torchbearer.metrics
    • Base Classes
      • Metric
        • Metric.eval()
        • Metric.process()
        • Metric.process_final()
        • Metric.reset()
        • Metric.train()
      • AdvancedMetric
        • AdvancedMetric.eval()
        • AdvancedMetric.process()
        • AdvancedMetric.process_final()
        • AdvancedMetric.process_final_train()
        • AdvancedMetric.process_final_validate()
        • AdvancedMetric.process_train()
        • AdvancedMetric.process_validate()
        • AdvancedMetric.train()
      • MetricList
        • MetricList.eval()
        • MetricList.process()
        • MetricList.process_final()
        • MetricList.reset()
        • MetricList.train()
      • MetricTree
        • MetricTree.add_child()
        • MetricTree.eval()
        • MetricTree.process()
        • MetricTree.process_final()
        • MetricTree.reset()
        • MetricTree.train()
      • add_default()
      • get_default()
    • Decorators - The Decorator API
      • default_for_key()
      • lambda_metric()
      • mean()
      • running_mean()
      • std()
      • to_dict()
      • var()
    • Metric Wrappers
      • BatchLambda
        • BatchLambda.process()
      • EpochLambda
        • EpochLambda.process_final_train()
        • EpochLambda.process_final_validate()
        • EpochLambda.process_train()
        • EpochLambda.process_validate()
        • EpochLambda.reset()
      • ToDict
        • ToDict.eval()
        • ToDict.process_final_train()
        • ToDict.process_final_validate()
        • ToDict.process_train()
        • ToDict.process_validate()
        • ToDict.reset()
        • ToDict.train()
    • Metric Aggregators
      • Mean
        • Mean.process()
        • Mean.process_final()
        • Mean.reset()
      • RunningMean
      • RunningMetric
        • RunningMetric.process_train()
        • RunningMetric.reset()
      • Std
        • Std.process_final()
      • Var
        • Var.process()
        • Var.process_final()
        • Var.reset()
    • Base Metrics
      • DefaultAccuracy
        • DefaultAccuracy.eval()
        • DefaultAccuracy.process()
        • DefaultAccuracy.process_final()
        • DefaultAccuracy.reset()
        • DefaultAccuracy.train()
      • BinaryAccuracy
      • CategoricalAccuracy
      • TopKCategoricalAccuracy
      • MeanSquaredError
      • Loss
      • Epoch
      • RocAucScore
    • Timer
      • TimerMetric
        • TimerMetric.get_timings()
        • TimerMetric.on_backward()
        • TimerMetric.on_criterion()
        • TimerMetric.on_criterion_validation()
        • TimerMetric.on_end()
        • TimerMetric.on_end_epoch()
        • TimerMetric.on_end_training()
        • TimerMetric.on_end_validation()
        • TimerMetric.on_forward()
        • TimerMetric.on_forward_validation()
        • TimerMetric.on_sample()
        • TimerMetric.on_sample_validation()
        • TimerMetric.on_start()
        • TimerMetric.on_start_epoch()
        • TimerMetric.on_start_training()
        • TimerMetric.on_start_validation()
        • TimerMetric.on_step_training()
        • TimerMetric.on_step_validation()
        • TimerMetric.process()
        • TimerMetric.reset()
        • TimerMetric.update_time()

Welcome to torchbearer’s documentation!

Notes

  • The Trial Class
  • The Metric API
  • The Callback API
  • Using DistributedDataParallel with Torchbearer on CPU
  • Using the Tensorboard Callback
  • Logging to Visdom

Package Reference

  • torchbearer
  • torchbearer.callbacks
  • torchbearer.metrics

Indices and tables

  • Index
  • Module Index
  • Search Page


© Copyright Torchbearer Contributors. Revision 0e8484c9.

Built with Sphinx using a theme provided by Read the Docs.