Created
February 10, 2020 19:06
-
-
Save isaacmg/4df4c99eca0991dfd87aa6a229b9e2d4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@TrainerBase.register('metatrainer')
class MetaTrainer(Trainer):
    def __init__(self,
                 model: Model,
                 meta_model: MetaModel,
                 optimizer: torch.optim.Optimizer,
                 iterator: DataIterator,
                 train_datasets: List[Iterable[Instance]],
                 validation_datasets: Optional[Iterable[Instance]] = None,
                 # meta learner parameters
                 meta_batches: int = 200,
                 inner_steps: int = 3,
                 tasks_per_batch: int = 2,
                 batch_norm: bool = True,
                 step_size: float = 0.01,
                 **kwargs) -> None:
        """
        A trainer for meta-learning over several labeled datasets.

        It takes a list of labeled datasets and a ``DataIterator``, and keeps the
        supplied meta-learner (``meta_model``) alongside the ``model`` whose
        weights are being learned. Validation datasets may also be supplied, and
        any other ``Trainer`` options (for example early stopping) are forwarded
        to the base ``Trainer`` through ``**kwargs``.

        Parameters
        ----------
        model : ``Model``, required.
            The model whose weights are trained. Passed to the base ``Trainer``.
        meta_model : ``MetaModel``, required.
            The meta-learner. Stored as ``self.meta_model``.
        optimizer : ``torch.optim.Optimizer``, required.
            Passed to the base ``Trainer``.
        iterator : ``DataIterator``, required.
            Passed to the base ``Trainer``.
        train_datasets : ``List[Iterable[Instance]]``, required.
            The labeled training datasets (presumably one per task -- TODO
            confirm). Passed to the base ``Trainer`` as its training data and
            also stored as ``self.train_data``.
        validation_datasets : ``Optional[Iterable[Instance]]``, optional (default = None)
            Stored as ``self._validation_data``. It is NOT passed to the base
            ``Trainer`` constructor; it is assigned after ``super().__init__``.
        meta_batches : ``int``, optional (default = 200)
            Stored as ``self.meta_batches``; presumably the number of outer
            (meta) batches to run -- TODO confirm.
        inner_steps : ``int``, optional (default = 3)
            Stored as ``self.inner_steps``; presumably the number of inner-loop
            update steps per task -- TODO confirm.
        tasks_per_batch : ``int``, optional (default = 2)
            Stored as ``self.tasks_per_batch``; presumably the number of tasks
            sampled per meta batch -- TODO confirm.
        batch_norm : ``bool``, optional (default = True)
            Stored as ``self.batch_norm``; how it is used is not shown here.
        step_size : ``float``, optional (default = 0.01)
            Stored as ``self.step_size``; presumably the meta-update step
            size -- TODO confirm.
        **kwargs
            Forwarded unchanged to ``Trainer.__init__``.
        """
        # I am not calling move_to_gpu here, because if the model is
        # not already on the GPU then the optimizer is going to be wrong.
        super().__init__(model, optimizer, iterator, train_datasets, **kwargs)
        self.train_data = train_datasets
        self._validation_data = validation_datasets
        # Keep a reference to the meta-learner; without this assignment the
        # ``meta_model`` argument would be unreachable after construction.
        self.meta_model = meta_model
        # Meta-trainer specific hyperparameters.
        self.meta_batches = meta_batches
        self.tasks_per_batch = tasks_per_batch
        self.inner_steps = inner_steps
        self.step_size = step_size
        self.batch_norm = batch_norm
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment