Last active
January 18, 2024 22:12
-
-
Save dvruette/72ecac9c623b89548ed3627d69acdf69 to your computer and use it in GitHub Desktop.
Faster Grid
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/grid.py b/grid.py | |
index f9f1557..8eafb91 100755 | |
--- a/grid.py | |
+++ b/grid.py | |
@@ -5,8 +5,10 @@ | |
# Written by Francois Fleuret <[email protected]> | |
-import math | |
-import torch, torchvision | |
+from concurrent.futures import ProcessPoolExecutor | |
+import os | |
+import tqdm | |
+import torch | |
import torch.nn.functional as F | |
###################################################################### | |
@@ -166,24 +168,24 @@ class GridFactory: | |
"white_smoke", | |
][:nb_colors] | |
- def generate_scene(self): | |
- nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2 | |
+ def generate_scene(self, generator=None):  # generator: optional torch.Generator forwarded to every random draw below | |
+ nb_items = torch.randint(self.max_nb_items - 1, (1,), generator=generator).item() + 2  # between 2 and max_nb_items items | |
col = torch.full((self.size * self.size,), -1) | |
shp = torch.full((self.size * self.size,), -1) | |
- a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items] | |
+ a = torch.randperm(len(self.name_colors) * len(self.name_shapes), generator=generator)[:nb_items]  # nb_items distinct (color, shape) codes | |
col[:nb_items] = a % len(self.name_colors) | |
shp[:nb_items] = a // len(self.name_colors) | |
- i = torch.randperm(self.size * self.size) | |
+ i = torch.randperm(self.size * self.size, generator=generator)  # shuffle which grid cells the items end up in | |
col = col[i] | |
shp = shp[i] | |
return col.reshape(self.size, self.size), shp.reshape(self.size, self.size) | |
- def random_transformations(self, scene): | |
+ def random_transformations(self, scene, generator=None): | |
col, shp = scene | |
descriptions = [] | |
- nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item() | |
- transformations = torch.randint(5, (nb_transformations,)) | |
+ nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,), generator=generator).item() | |
+ transformations = torch.randint(5, (nb_transformations,), generator=generator) | |
for t in transformations: | |
if t == 0: | |
@@ -284,11 +286,15 @@ class GridFactory: | |
return properties | |
- def generate_scene_and_questions(self): | |
+ def generate_scene_and_questions(self, seed=None): # seeds this call's own torch.Generator; executor.map supplies one seed per sample | |
+ rng = torch.Generator() | |
+ if seed is not None: | |
+ rng.manual_seed(seed) | |
+ | |
while True: | |
while True: | |
- start_scene = self.generate_scene() | |
- scene, transformations = self.random_transformations(start_scene) | |
+ start_scene = self.generate_scene(generator=rng) | |
+ scene, transformations = self.random_transformations(start_scene, generator=rng) | |
true = self.all_properties(scene) | |
if len(true) >= self.nb_questions: | |
break | |
@@ -296,7 +302,7 @@ class GridFactory: | |
for a in range(10): | |
col, shp = scene | |
col, shp = col.view(-1), shp.view(-1) | |
- p = torch.randperm(col.size(0)) | |
+ p = torch.randperm(col.size(0), generator=rng) | |
col, shp = col[p], shp[p] | |
other_scene = ( | |
col.view(self.size, self.size), | |
@@ -308,8 +314,8 @@ class GridFactory: | |
# We sometime add properties from a totally different | |
# scene to have negative "there is a xxx xxx" | |
# properties | |
- if torch.rand(1).item() < 0.2: | |
- other_scene = self.generate_scene() | |
+ if torch.rand(1, generator=rng).item() < 0.2: | |
+ other_scene = self.generate_scene(generator=rng) | |
false += self.all_properties(other_scene) | |
false = list(set(false) - set(true)) | |
@@ -319,13 +325,13 @@ class GridFactory: | |
if a < 10: | |
break | |
- true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]] | |
- false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]] | |
+ true = [true[k] for k in torch.randperm(len(true), generator=rng)[: self.nb_questions]] | |
+ false = [false[k] for k in torch.randperm(len(false), generator=rng)[: self.nb_questions]] | |
true = ["<prop> " + q + " <ans> true" for q in true] | |
false = ["<prop> " + q + " <ans> false" for q in false] | |
union = true + false | |
- questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]] | |
+ questions = [union[k] for k in torch.randperm(len(union), generator=rng)[: self.nb_questions]] | |
result = " ".join( | |
["<obj> " + x for x in self.grid_positions(start_scene)] | |
@@ -335,15 +341,28 @@ class GridFactory: | |
return start_scene, scene, result | |
- def generate_samples(self, nb, progress_bar=None): | |
+ def generate_samples(self, nb, show_progress=False, num_workers="auto", seed=None): | |
result = [] | |
- r = range(nb) | |
- if progress_bar is not None: | |
- r = progress_bar(r) | |
- | |
- for _ in r: | |
- result.append(self.generate_scene_and_questions()[2]) | |
+ rng = torch.Generator() | |
+ if seed is None: | |
+ seed = torch.randint(0, 2 ** 32 - 1, (1,)).item() | |
+ rng.manual_seed(seed) | |
+ | |
+ if isinstance(num_workers, str): | |
+ num_workers = os.cpu_count() | |
+ | |
+ seeds = torch.randint(0, 2 ** 32 - 1, (nb,), generator=rng).tolist()  # one seed per sample, drawn up front so both branches use the same seeds | |
+ with tqdm.tqdm(total=nb, smoothing=0.01, disable=not show_progress) as pbar: | |
+ if num_workers == 1: | |
+ for sample_seed in seeds: | |
+ result.append(self.generate_scene_and_questions(sample_seed)[2])  # an unseeded call would reuse torch.Generator()'s fixed default seed and repeat the same sample | |
+ pbar.update() | |
+ else: | |
+ with ProcessPoolExecutor(max_workers=num_workers) as executor: | |
+ for sample in executor.map(self.generate_scene_and_questions, seeds, chunksize=32): | |
+ result.append(sample[2]) | |
+ pbar.update() | |
return result | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment