trw.train.sequence_collate

Module Contents

Classes

SequenceCollate

Group the data into a sequence of dictionaries of torch.Tensor

class trw.train.sequence_collate.SequenceCollate(source_split, collate_fn=collate.default_collate_fn, device=None)

Bases: trw.train.sequence.Sequence, trw.train.sequence.SequenceIterator

Group the data into a sequence of dictionaries of torch.Tensor

This is useful to combine batches of dictionaries into a single batch with all features concatenated on axis 0. It is often used in conjunction with trw.train.SequenceAsyncReservoir and trw.train.SequenceMap.
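A minimal sketch of what "concatenated on axis 0" means for a list of batches. `collate_batches` is a hypothetical helper, not part of trw, and plain Python lists stand in for torch.Tensor; with real tensors each feature would be merged with `torch.cat(values, dim=0)`.

```python
from typing import Dict, List, Sequence


def collate_batches(batches: Sequence[Dict[str, list]]) -> Dict[str, list]:
    """Merge several batches (dicts of per-sample lists) into one batch.

    Hypothetical helper: each feature is concatenated along axis 0, the
    sample axis. Lists stand in for torch.Tensor (torch.cat(..., dim=0)).
    """
    if not batches:
        return {}
    merged: Dict[str, list] = {}
    for key in batches[0].keys():
        values: List = []
        for batch in batches:
            # every batch is assumed to expose the same features
            values.extend(batch[key])
        merged[key] = values
    return merged


batches = [
    {"images": [[0.1], [0.2]], "uid": [0, 1]},
    {"images": [[0.3]], "uid": [2]},
]
print(collate_batches(batches))
# {'images': [[0.1], [0.2], [0.3]], 'uid': [0, 1, 2]}
```

The two small batches become one batch of three samples, with each feature keeping its per-sample order.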

subsample(self, nb_samples)

Sub-sample a sequence to a fixed number of samples.

The purpose is to obtain a smaller sequence; this is particularly useful for exporting augmentations or samples.

Parameters

nb_samples – the number of samples to keep from the original sequence

Returns

a subsampled Sequence
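A sketch of the idea behind subsample: choose a fixed number of sample indices from the original data. Uniform sampling without replacement and preserving the original order are assumptions of this sketch, and `subsample_indices` is a hypothetical helper; the actual selection strategy of trw is not described here.

```python
import random
from typing import List, Optional


def subsample_indices(nb_total: int, nb_samples: int, seed: Optional[int] = None) -> List[int]:
    """Pick `nb_samples` distinct indices out of `nb_total` (hypothetical helper).

    Sampling is uniform without replacement (an assumption), and the result
    is sorted so the subsample keeps the original sample order.
    """
    nb_samples = min(nb_samples, nb_total)  # cannot keep more samples than exist
    rng = random.Random(seed)
    return sorted(rng.sample(range(nb_total), nb_samples))


samples = ["s0", "s1", "s2", "s3", "s4", "s5"]
kept = [samples[i] for i in subsample_indices(len(samples), nb_samples=3, seed=0)]
print(kept)  # three of the six samples, in their original order
```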

subsample_uids(self, uids, uids_name, new_sampler=None)

Sub-sample a sequence to samples with specified UIDs.

Parameters
  • uids (list) – the UIDs of the samples to keep. If new_sampler preserves ordering, the samples of the resampled sequence follow the ordering of uids

  • uids_name (str) – the name of the UIDs

  • new_sampler (Sampler) – the sampler to be used for the subsampled sequence. If None, the existing sampler is reused

Returns

a subsampled Sequence
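A sketch of UID-based selection on a single batch, assuming the UIDs are stored as one of the batch features under the key given by `uids_name`. `select_by_uids` is hypothetical and only mirrors the documented ordering behaviour: the output follows the order of `uids`, and UIDs not present in the batch are skipped.

```python
from typing import Dict, Hashable, Sequence


def select_by_uids(batch: Dict[str, list], uids: Sequence[Hashable], uids_name: str) -> Dict[str, list]:
    """Keep only the samples whose UID (stored under `uids_name`) is in `uids`.

    Hypothetical helper: rows are emitted in the order of `uids`;
    UIDs absent from the batch are ignored.
    """
    position = {uid: i for i, uid in enumerate(batch[uids_name])}
    rows = [position[uid] for uid in uids if uid in position]
    return {name: [values[r] for r in rows] for name, values in batch.items()}


batch = {"uid": ["a", "b", "c", "d"], "label": [0, 1, 0, 1]}
print(select_by_uids(batch, uids=["d", "a"], uids_name="uid"))
# {'uid': ['d', 'a'], 'label': [1, 0]}
```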

__next__(self)

Returns

The next batch of data

__iter__(self)

Returns

An iterator of batches
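A toy illustration of the __iter__/__next__ pair: the iterator pulls one item at a time from a source and applies a collate function to it. `CollatingIterator` and `to_batch` are hypothetical stand-ins; this sketch does not reproduce SequenceCollate's device handling or its exact internals.

```python
from typing import Callable, Dict, Iterable, Iterator, List


class CollatingIterator:
    """Hypothetical iterator: applies `collate_fn` to each item of `source`."""

    def __init__(self, source: Iterable, collate_fn: Callable):
        self.source = source
        self.collate_fn = collate_fn
        self._it: Iterator = iter(())  # empty until __iter__ is called

    def __iter__(self) -> "CollatingIterator":
        # (re)start iteration over the underlying source
        self._it = iter(self.source)
        return self

    def __next__(self):
        # StopIteration raised by the source propagates and ends the loop
        return self.collate_fn(next(self._it))


def to_batch(samples: List[Dict[str, int]]) -> Dict[str, List[int]]:
    """Turn a list of sample dicts into one dict of per-feature lists."""
    return {key: [s[key] for s in samples] for key in samples[0]}


source = [[{"x": 1}, {"x": 2}], [{"x": 3}]]
for batch in CollatingIterator(source, to_batch):
    print(batch)
# {'x': [1, 2]}
# {'x': [3]}
```

Returning `self` from `__iter__` makes the object both the iterable and its iterator, which matches the class listing both Sequence and SequenceIterator as bases.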

close(self)

Special method to close the sequence and clean up its resources
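Because the sequence exposes a close() method, the standard library helper `contextlib.closing` can guarantee it is called even if iteration raises. The sketch below uses a hypothetical `ToySequence` in place of a real SequenceCollate instance.

```python
import contextlib


class ToySequence:
    """Hypothetical sequence holding a resource that must be released."""

    def __init__(self):
        self.closed = False

    def __iter__(self):
        return iter([{"x": [1, 2]}, {"x": [3]}])

    def close(self):
        # a real sequence would stop worker processes, free buffers, etc.
        self.closed = True


seq = ToySequence()
with contextlib.closing(seq) as s:
    for batch in s:
        pass  # consume the batches
print(seq.closed)  # True
```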