trw.arch.darts_optimizer
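Module Contents
Functions
- _get_parameters(model, is_darts_weight_dataset_name) – Capture the parameters relevant to DARTS and to the dataset.
- create_darts_optimizers_fn(datasets, model, optimizer_fn, darts_weight_dataset_name, scheduler_fn=None) – Create an optimizer and scheduler for DARTS architecture search.
- create_darts_adam_optimizers_fn(datasets, model, darts_weight_dataset_name, learning_rate, scheduler_fn=None) – Create an Adam optimizer and scheduler for DARTS architecture search.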
- trw.arch.darts_optimizer._get_parameters(model, is_darts_weight_dataset_name)
Capture the parameters relevant to DARTS and to the dataset.
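The selection rule that _get_parameters appears to implement can be sketched in plain Python. This is a minimal, framework-free sketch under stated assumptions: `Parameter`, `SpecialParameter` and `ToyModel` are hypothetical stand-ins for `torch.nn.Parameter`, `trw.arch.SpecialParameter` and a `torch.nn.Module`, and the boolean filter is inferred from the description of create_darts_optimizers_fn rather than taken from the trw source.

```python
class Parameter:
    """Hypothetical stand-in for an ordinary trainable parameter (torch.nn.Parameter)."""

    def __init__(self, name):
        self.name = name


class SpecialParameter(Parameter):
    """Hypothetical stand-in for trw.arch.SpecialParameter: a marker subclass."""


class ToyModel:
    """Hypothetical model exposing parameters() like a torch.nn.Module."""

    def __init__(self, params):
        self._params = list(params)

    def parameters(self):
        return iter(self._params)


def get_parameters(model, is_darts_weight_dataset_name):
    """Assumed rule: keep only SpecialParameter instances for the DARTS-weight
    dataset, and only the ordinary parameters for any other dataset."""
    return [
        p for p in model.parameters()
        if isinstance(p, SpecialParameter) == is_darts_weight_dataset_name
    ]


model = ToyModel([Parameter("conv.weight"), SpecialParameter("cell.alpha"), Parameter("fc.bias")])
print([p.name for p in get_parameters(model, True)])   # ['cell.alpha']
print([p.name for p in get_parameters(model, False)])  # ['conv.weight', 'fc.bias']
```

Because the marker is a subclass, a single `isinstance` check is enough to split the two groups, and every parameter lands in exactly one of them.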
- trw.arch.darts_optimizer.create_darts_optimizers_fn(datasets, model, optimizer_fn, darts_weight_dataset_name, scheduler_fn=None)
Create an optimizer and scheduler for DARTS architecture search.
In particular, parameters derived from trw.arch.SpecialParameter are handled differently:
  - for each dataset other than darts_weight_dataset_name, all the parameters NOT derived from trw.arch.SpecialParameter are optimized
  - on the dataset darts_weight_dataset_name, ONLY the parameters derived from trw.arch.SpecialParameter are optimized
Note
If model is an instance of `ModuleDict`, the optimizer for each dataset only considers the parameters model[dataset_name].parameters(); otherwise it considers model.parameters().
- Parameters
datasets – a dictionary of datasets
model – the model. Should be a Module or a ModuleDict
optimizer_fn – the functor to instantiate the optimizer
scheduler_fn – the functor to instantiate the scheduler. May be None, in which case there will be no scheduler
darts_weight_dataset_name – the name of the dataset used to train the DARTS cell weights. Only the parameters of the model derived from trw.arch.SpecialParameter are optimized on this dataset
- Returns
a dict of optimizers, one per dataset
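A rough sketch of how the per-dataset optimizers described above could be assembled, again with hypothetical stand-ins rather than the trw implementation. Assumptions: a plain dict plays the role of `ModuleDict`, `optimizer_fn` is called with the list of selected parameters, `scheduler_fn` is called with the resulting optimizer, and the schedulers are returned in a second dict (this page does not say where they end up). `create_darts_optimizers_sketch`, `RecordingOptimizer` and the dataset names are all illustrative.

```python
class Parameter:
    """Hypothetical stand-in for torch.nn.Parameter."""

    def __init__(self, name):
        self.name = name


class SpecialParameter(Parameter):
    """Hypothetical stand-in for trw.arch.SpecialParameter (a marker subclass)."""


class ToyModel:
    """Hypothetical stand-in for a torch.nn.Module exposing parameters()."""

    def __init__(self, params):
        self._params = list(params)

    def parameters(self):
        return iter(self._params)


class RecordingOptimizer:
    """Hypothetical optimizer that only records the parameters it receives."""

    def __init__(self, params):
        self.params = list(params)


def create_darts_optimizers_sketch(datasets, model, optimizer_fn,
                                   darts_weight_dataset_name, scheduler_fn=None):
    """Build one optimizer (and optionally one scheduler) per dataset."""
    optimizers = {}
    schedulers = {}  # assumption: schedulers are returned in a separate dict
    for dataset_name in datasets:
        # a plain dict stands in for ModuleDict: one sub-model per dataset
        sub_model = model[dataset_name] if isinstance(model, dict) else model
        is_darts = dataset_name == darts_weight_dataset_name
        # DARTS dataset: only SpecialParameter; other datasets: everything else
        params = [p for p in sub_model.parameters()
                  if isinstance(p, SpecialParameter) == is_darts]
        optimizers[dataset_name] = optimizer_fn(params)
        if scheduler_fn is not None:
            schedulers[dataset_name] = scheduler_fn(optimizers[dataset_name])
    return optimizers, schedulers


model = ToyModel([Parameter("conv.weight"), SpecialParameter("cell.alpha")])
datasets = {"main": {}, "darts_weights": {}}  # hypothetical dataset names
optimizers, schedulers = create_darts_optimizers_sketch(
    datasets, model, RecordingOptimizer, "darts_weights")
print([p.name for p in optimizers["main"].params])           # ['conv.weight']
print([p.name for p in optimizers["darts_weights"].params])  # ['cell.alpha']
print(schedulers)                                            # {}
```

Selecting the parameters before calling `optimizer_fn` keeps the optimizer factory unaware of DARTS: the same functor serves both the ordinary weights and the architecture weights.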
- trw.arch.darts_optimizer.create_darts_adam_optimizers_fn(datasets, model, darts_weight_dataset_name, learning_rate, scheduler_fn=None)
Create an Adam optimizer and scheduler for DARTS architecture search.
- Parameters
datasets – a dictionary of datasets
model – a model to optimize
learning_rate – the initial learning rate
scheduler_fn – the functor to instantiate the scheduler, or None for no scheduler
darts_weight_dataset_name – the name of the dataset used to train the DARTS cell weights. Only the parameters of the model derived from trw.arch.SpecialParameter are optimized on this dataset
- Returns
An optimizer
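The Adam variant presumably differs from create_darts_optimizers_fn only in that it fixes `optimizer_fn` to Adam with the given `learning_rate`; this page does not confirm that. The sketch below shows the idea with `functools.partial` and a hypothetical `FakeAdam` stand-in, assuming `optimizer_fn` receives the selected parameters as its only argument. With PyTorch, the equivalent functor would be `functools.partial(torch.optim.Adam, lr=learning_rate)`, since `torch.optim.Adam` takes the parameters as its first argument and the learning rate as `lr`.

```python
import functools


class FakeAdam:
    """Hypothetical stand-in for torch.optim.Adam(params, lr=...)."""

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr


def make_adam_optimizer_fn(learning_rate):
    """Bind the learning rate so the result only needs the parameters,
    matching the assumed optimizer_fn(params) shape used per dataset."""
    return functools.partial(FakeAdam, lr=learning_rate)


optimizer_fn = make_adam_optimizer_fn(learning_rate=1e-3)
optimizer = optimizer_fn(["cell.alpha"])
print(optimizer.lr, optimizer.params)  # 0.001 ['cell.alpha']
```

Binding the learning rate once means every dataset's optimizer, including the one for darts_weight_dataset_name, starts from the same initial learning rate.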