# Owner(s): ["module: unknown"] import collections import unittest import torch from torch.testing._internal.common_utils import ( TestCase, run_tests, TEST_WITH_ASAN) try: import psutil HAS_PSUTIL = True except ImportError: HAS_PSUTIL = False device = torch.device('cpu') class Network(torch.nn.Module): maxp1 = torch.nn.MaxPool2d(1, 1) def forward(self, x): return self.maxp1(x) @unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run") @unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN") class TestOpenMP_ParallelFor(TestCase): batch = 20 channels = 1 side_dim = 80 x = torch.randn([batch, channels, side_dim, side_dim], device=device) model = Network() def func(self, runs): p = psutil.Process() # warm up for 5 runs, then things should be stable for the last 5 last_rss = collections.deque(maxlen=5) for n in range(10): for i in range(runs): self.model(self.x) last_rss.append(p.memory_info().rss) return last_rss def func_rss(self, runs): last_rss = list(self.func(runs)) # Check that the sequence is not strictly increasing is_increasing = True for idx in range(len(last_rss)): if idx == 0: continue is_increasing = is_increasing and (last_rss[idx] > last_rss[idx - 1]) self.assertTrue(not is_increasing, msg='memory usage is increasing, {}'.format(str(last_rss))) def test_one_thread(self): """Make sure there is no memory leak with one thread: issue gh-32284 """ torch.set_num_threads(1) self.func_rss(300) def test_n_threads(self): """Make sure there is no memory leak with many threads """ ncores = min(5, psutil.cpu_count(logical=False)) torch.set_num_threads(ncores) self.func_rss(300) if __name__ == '__main__': run_tests()