trw.train.sampler
Module Contents
Classes
Sampler | Base class for all Samplers.
_SamplerSequentialIter | Lazily iterate the indices of a sequential batch.
SamplerSequential | Samples elements sequentially, always in the same order.
SamplerRandom | Samples elements randomly. If without replacement, then sample from a shuffled dataset.
SamplerSubsetRandom | Samples elements randomly from a given list of indices, without replacement.
SamplerSubsetRandomByListInterleaved | Elements from a given list of lists of indices are randomly drawn without replacement, one element per list at a time.
SamplerClassResampling | Resample the samples so that class_name classes have equal probability of being sampled.
- class trw.train.sampler.Sampler
Bases: object
Base class for all Samplers.
Every Sampler subclass has to provide an __iter__ method, providing a way to iterate over indices of dataset elements, and a __len__ method that returns the length of the returned iterators.
- abstract initializer(self, data_source)
Initialize the sequence iteration
- Parameters
data_source – the data source to iterate
- abstract __iter__(self)
Returns: an iterator that returns indices of the original data source
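As an illustration of the interface described above, the sketch below defines a custom sampler. It does not import trw: `SamplerBase` is a hypothetical stand-in that only mirrors the documented `initializer` / `__iter__` contract, and `SamplerEveryOther` is an invented example class, not part of the library.

```python
from abc import ABC, abstractmethod


class SamplerBase(ABC):
    """Hypothetical stand-in mirroring the documented Sampler contract."""

    @abstractmethod
    def initializer(self, data_source):
        """Prepare the sampler for the given data source."""

    @abstractmethod
    def __iter__(self):
        """Return an iterator over indices of the data source."""


class SamplerEveryOther(SamplerBase):
    """Invented example: yields every second index, in order."""

    def __init__(self):
        self.nb_samples = 0

    def initializer(self, data_source):
        # Assumption: the data source supports len(); a real trw data
        # source may expose its size differently.
        self.nb_samples = len(data_source)

    def __iter__(self):
        return iter(range(0, self.nb_samples, 2))


sampler = SamplerEveryOther()
sampler.initializer(["a", "b", "c", "d", "e"])
every_other = list(sampler)
```

Calling `initializer` before iterating matches the usage shown in this module's own example, where the sampler is initialized and then consumed with a loop.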
- class trw.train.sampler._SamplerSequentialIter(nb_samples, batch_size)
Lazily iterate the indices of a sequential batch.
- __next__(self)
- class trw.train.sampler.SamplerSequential(batch_size=1)
Bases: Sampler
Samples elements sequentially, always in the same order.
- initializer(self, data_source)
Initialize the sequence iteration
- Parameters
data_source – the data source to iterate
- __iter__(self)
Returns: an iterator that returns indices of the original data source
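Sequential sampling with a batch size can be pictured with the generator below. This is a minimal sketch, not the library's implementation: the function name `sequential_batches` is hypothetical, and keeping the final partial batch and returning plain lists are both assumptions.

```python
def sequential_batches(nb_samples, batch_size=1):
    """Lazily yield lists of consecutive indices, batch_size at a time."""
    for start in range(0, nb_samples, batch_size):
        # Assumption: the last batch is kept even if it is smaller.
        end = min(start + batch_size, nb_samples)
        yield list(range(start, end))


batches = list(sequential_batches(nb_samples=5, batch_size=2))
```

Because the indices are generated on demand, nothing proportional to the dataset size is stored, which is the point of a lazy iterator such as `_SamplerSequentialIter`.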
- class trw.train.sampler.SamplerRandom(replacement=False, nb_samples_to_generate=None, batch_size=1)
Bases: Sampler
Samples elements randomly. If without replacement, then sample from a shuffled dataset. If with replacement, then the user can specify the number of samples to draw (nb_samples_to_generate).
- initializer(self, data_source)
Initialize the sequence iteration
- Parameters
data_source – the data source to iterate
- __iter__(self)
Returns: an iterator that returns indices of the original data source
- __next__(self)
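The difference between the two modes can be sketched with the standard library alone. The function `random_indices` and its `seed` parameter are hypothetical; the `replacement` and `nb_samples_to_generate` names are borrowed from the signature above, and batching is left out for brevity.

```python
import random


def random_indices(nb_samples, replacement=False,
                   nb_samples_to_generate=None, seed=None):
    """Return randomly ordered indices in [0, nb_samples)."""
    rng = random.Random(seed)
    if replacement:
        # Each draw is independent, so an index may appear several times.
        # Assumption: default to one draw per sample when no count is given.
        count = nb_samples if nb_samples_to_generate is None else nb_samples_to_generate
        return [rng.randrange(nb_samples) for _ in range(count)]
    # Without replacement: a random permutation of all indices.
    indices = list(range(nb_samples))
    rng.shuffle(indices)
    return indices


shuffled = random_indices(6, seed=0)
with_repeats = random_indices(6, replacement=True,
                              nb_samples_to_generate=10, seed=0)
```

Without replacement every index appears exactly once per pass; with replacement the number of draws is independent of the dataset size.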
- class trw.train.sampler.SamplerSubsetRandom(indices)
Bases: Sampler
Samples elements randomly from a given list of indices, without replacement.
- Parameters
indices (sequence) – a sequence of indices
- initializer(self, data_source)
Initialize the sequence iteration
- Parameters
data_source – the data source to iterate
- __iter__(self)
Returns: an iterator that returns indices of the original data source
- class trw.train.sampler.SamplerSubsetRandomByListInterleaved(indices: Sequence[Sequence[int]])
Bases: Sampler
Elements from a given list of lists of indices are randomly drawn without replacement, one element per list at a time.
For sequences of different sizes, the longer sequences are trimmed to the size of the shortest sequence.
This can be used for example to resample without replacement imbalanced classes in a classification task.
Examples:
>>> import numpy as np
>>> import trw
>>> l1 = np.asarray([1, 2])
>>> l2 = np.asarray([3, 4, 5])
>>> sampler = trw.train.SamplerSubsetRandomByListInterleaved([l1, l2])
>>> sampler.initializer(None)
>>> indices = [i for i in sampler]  # indices could be [1, 5, 2, 4]
- Parameters
indices – a sequence of sequences of indices
- initializer(self, data_source)
Initialize the sequence iteration
- Parameters
data_source – the data source to iterate
- __iter__(self)
Returns: an iterator that returns indices of the original data source
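The trimming and interleaving described for this class can be sketched as follows. The function `interleave_without_replacement` and its `seed` parameter are hypothetical, and the sketch uses the standard library rather than NumPy; it illustrates the documented behaviour and is not the library's code.

```python
import random


def interleave_without_replacement(index_lists, seed=None):
    """Draw one element per list in turn, without replacement."""
    rng = random.Random(seed)
    # Longer lists are trimmed to the length of the shortest one.
    shortest = min(len(indices) for indices in index_lists)
    # random.sample returns `shortest` distinct elements in random order,
    # which is equivalent to shuffling and then trimming.
    shuffled = [rng.sample(list(indices), shortest) for indices in index_lists]
    # Round i takes the i-th element of every list before moving on.
    return [lst[i] for i in range(shortest) for lst in shuffled]


drawn = interleave_without_replacement([[1, 2], [3, 4, 5]], seed=0)
```

With one list of indices per class, consecutive draws alternate between classes, so every class contributes the same number of samples per pass.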
- class trw.train.sampler.SamplerClassResampling(class_name, nb_samples_to_generate, reuse_class_frequencies_across_epochs=True, batch_size=1)
Bases: Sampler
Resample the samples so that class_name classes have equal probability of being sampled.
Classification problems rarely have balanced classes, so it is often necessary to oversample the minority class to avoid penalizing the underrepresented classes and to help the classifier learn good features (as opposed to learning the class distribution).
- initializer(self, data_source)
Initialize the sequence iteration
- Parameters
data_source – the data source to iterate
- _fit(self, classes)
- __next__(self)
- __iter__(self)
Returns: an iterator that returns indices of the original data source
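One way to give every class the same probability of being sampled is to pick a class uniformly and then pick a sample uniformly within that class. The sketch below takes that approach as an assumption; how `SamplerClassResampling` actually computes its sampling weights is not stated here. The function `resample_balanced` and its `seed` parameter are hypothetical.

```python
import random
from collections import defaultdict


def resample_balanced(classes, nb_samples_to_generate, seed=None):
    """Draw indices so that each class is equally likely to be picked."""
    rng = random.Random(seed)
    # Group sample indices by their class label.
    indices_by_class = defaultdict(list)
    for index, label in enumerate(classes):
        indices_by_class[label].append(index)
    labels = sorted(indices_by_class)
    # Pick a class uniformly, then a sample uniformly within that class.
    # Minority samples are therefore drawn repeatedly (with replacement).
    return [
        rng.choice(indices_by_class[rng.choice(labels)])
        for _ in range(nb_samples_to_generate)
    ]


classes = [0] * 9 + [1]  # 90% majority class, 10% minority class
resampled = resample_balanced(classes, nb_samples_to_generate=1000, seed=0)
minority_share = sum(classes[i] == 1 for i in resampled) / len(resampled)
```

Although the minority class makes up a tenth of the data, it accounts for roughly half of the resampled indices, so the classifier cannot score well by learning the class distribution alone.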