Created
December 16, 2023 08:30
-
-
Save Geson-anko/eb67f7285c78b3e5f4c2d2268df50d80 to your computer and use it in GitHub Desktop.
学習と推論を非同期に行う処理のモックアップです。 (A mock-up of running training and inference asynchronously.)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
import threading | |
import time | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader, TensorDataset | |
class Inference:
    """Thread-safe holder that serves a frozen copy of a model for inference."""

    def __init__(self, device: torch.device) -> None:
        self.device = device
        # RLock so the model setter and `infer` can never interleave mid-swap.
        self.lock = threading.RLock()
        self._model = None

    @property
    def model(self) -> nn.Module:
        """Return the attached model; raise if nothing has been attached yet."""
        if self._model is None:
            raise RuntimeError("Not attached model.")
        return self._model

    @model.setter
    def model(self, m: nn.Module) -> None:
        # Swap under the lock so an in-flight `infer` never observes a
        # half-configured module (wrong mode or wrong device).
        with self.lock:
            self._model = m
            self._model.eval()
            self._model.to(self.device)

    def attach_model_from_neural_nets(self, neural_nets: dict[str, nn.Module]) -> None:
        """Attach a deep copy of the encoder so training never mutates it in place."""
        self.model = copy.deepcopy(neural_nets["encoder"])

    @torch.inference_mode()
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        """Run the attached model on `x` (moved to this device) under the lock."""
        with self.lock:
            return self.model(x.to(self.device))
class AutoEncoder(nn.Module):
    """Composes an encoder and a decoder into one reconstruction model."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode `x` into the latent space, then decode it back."""
        latent = self.encoder(x)
        return self.decoder(latent)
class Trainer: | |
def __init__(self, device: torch.device) -> None: | |
self._inference_model = None | |
self.device = device | |
def build(self, neural_nets: nn.ModuleDict) -> None: | |
self.net = AutoEncoder(neural_nets["encoder"], neural_nets["decoder"]) | |
self.net.to(self.device) | |
self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.001) | |
self.data_loader = DataLoader(TensorDataset(torch.zeros(16, 28 * 28)), 8) | |
@property | |
def neural_nets(self) -> nn.ModuleDict: | |
if hasattr(self, "_neural_nets"): | |
return self._neural_nets | |
raise RuntimeError("Neural Nets is not attached to trainer!") | |
@neural_nets.setter | |
def neural_nets(self, n: nn.ModuleDict) -> None: | |
self._neural_nets = n | |
@property | |
def inference_model(self) -> Inference: | |
if self._inference_model is None: | |
raise RuntimeError("Inference model is not attached to trainer!") | |
return self._inference_model | |
@inference_model.setter | |
def inference_model(self, m: Inference) -> None: | |
self._inference_model = m | |
def train(self): | |
for data in self.data_loader: | |
x = data[0].to(self.device) | |
out = self.net(x) | |
loss = F.mse_loss(x, out) | |
loss.backward() | |
self.optimizer.step() | |
self.optimizer.zero_grad() | |
print("Loss:", loss.item()) | |
def sync(self) -> None: | |
trained_model = self.neural_nets["encoder"] | |
trained_model.eval() | |
trained_param = trained_model.state_dict() | |
untrained = self.inference_model.model | |
self.inference_model.model = trained_model | |
untrained.load_state_dict(trained_param) | |
untrained.train() | |
self.neural_nets["encoder"] = untrained | |
def __call__(self): | |
self.build(self.neural_nets) | |
self.train() | |
self.sync() | |
class System:
    """Wires the networks into an Inference server and a Trainer, then runs
    the inference and training loops on separate threads forever."""

    def __init__(self, device="cpu"):
        encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 8),
        )
        decoder = nn.Sequential(
            nn.Linear(8, 32),
            nn.ReLU(),
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
        )
        neural_nets = nn.ModuleDict({"encoder": encoder, "decoder": decoder})
        # The inference side gets its own deep copy of the encoder; the
        # trainer keeps the originals and swaps weights in via sync().
        self.inference_model = Inference(device)
        self.inference_model.attach_model_from_neural_nets(neural_nets)
        self.trainer = Trainer(device)
        self.trainer.inference_model = self.inference_model
        self.trainer.neural_nets = neural_nets

    def inference_loop(self):
        """Repeatedly infer on a dummy observation (never returns)."""
        while True:
            obs = torch.zeros(28 * 28)
            out = self.inference_model.infer(obs)
            print("Infered on ", id(self.inference_model.model))
            time.sleep(0.5)

    def training_loop(self):
        """Repeatedly run a full train/sync cycle (never returns)."""
        while True:
            self.trainer()
            print("Trained")
            time.sleep(2.0)

    def main(self):
        """Start both loops on their own threads and block on them forever."""
        workers = [
            threading.Thread(target=self.inference_loop),
            threading.Thread(target=self.training_loop),
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
if __name__ == "__main__":
    # The original hard-coded "mps", which crashes anywhere the Apple MPS
    # backend is unavailable; fall back to CPU so the mock-up runs everywhere.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    system = System(device=device)
    system.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment