# Owner(s): ["module: unknown"] import threading import time import torch import unittest from torch.futures import Future from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests from typing import TypeVar T = TypeVar("T") def add_one(fut): return fut.wait() + 1 class TestFuture(TestCase): def test_set_exception(self) -> None: # This test is to ensure errors can propagate across futures. error_msg = "Intentional Value Error" value_error = ValueError(error_msg) f = Future[T]() # Set exception f.set_exception(value_error) # Exception should throw on wait with self.assertRaisesRegex(ValueError, "Intentional"): f.wait() # Exception should also throw on value f = Future() f.set_exception(value_error) with self.assertRaisesRegex(ValueError, "Intentional"): f.value() def cb(fut): fut.value() f = Future() f.set_exception(value_error) with self.assertRaisesRegex(RuntimeError, "Got the following error"): cb_fut = f.then(cb) cb_fut.wait() def test_set_exception_multithreading(self) -> None: # Ensure errors can propagate when one thread waits on future result # and the other sets it with an error. error_msg = "Intentional Value Error" value_error = ValueError(error_msg) def wait_future(f): with self.assertRaisesRegex(ValueError, "Intentional"): f.wait() f = Future[T]() t = threading.Thread(target=wait_future, args=(f, )) t.start() f.set_exception(value_error) t.join() def cb(fut): fut.value() def then_future(f): fut = f.then(cb) with self.assertRaisesRegex(RuntimeError, "Got the following error"): fut.wait() f = Future[T]() t = threading.Thread(target=then_future, args=(f, )) t.start() f.set_exception(value_error) t.join() def test_done(self) -> None: f = Future[torch.Tensor]() self.assertFalse(f.done()) f.set_result(torch.ones(2, 2)) self.assertTrue(f.done()) def test_done_exception(self) -> None: err_msg = "Intentional Value Error" def raise_exception(unused_future): raise RuntimeError(err_msg) f1 = Future[torch.Tensor]() self.assertFalse(f1.done()) f1.set_result(torch.ones(2, 2)) self.assertTrue(f1.done()) f2 = f1.then(raise_exception) self.assertTrue(f2.done()) with self.assertRaisesRegex(RuntimeError, err_msg): f2.wait() def test_wait(self) -> None: f = Future[torch.Tensor]() f.set_result(torch.ones(2, 2)) self.assertEqual(f.wait(), torch.ones(2, 2)) def test_wait_multi_thread(self) -> None: def slow_set_future(fut, value): time.sleep(0.5) fut.set_result(value) f = Future[torch.Tensor]() t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2))) t.start() self.assertEqual(f.wait(), torch.ones(2, 2)) t.join() def test_mark_future_twice(self) -> None: fut = Future[int]() fut.set_result(1) with self.assertRaisesRegex( RuntimeError, "Future can only be marked completed once" ): fut.set_result(1) def test_pickle_future(self): fut = Future[int]() errMsg = "Can not pickle torch.futures.Future" with TemporaryFileName() as fname: with self.assertRaisesRegex(RuntimeError, errMsg): torch.save(fut, fname) def test_then(self): fut = Future[torch.Tensor]() then_fut = fut.then(lambda x: x.wait() + 1) fut.set_result(torch.ones(2, 2)) self.assertEqual(fut.wait(), torch.ones(2, 2)) self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1) def test_chained_then(self): fut = Future[torch.Tensor]() futs = [] last_fut = fut for _ in range(20): last_fut = last_fut.then(add_one) futs.append(last_fut) fut.set_result(torch.ones(2, 2)) for i in range(len(futs)): self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1) def _test_then_error(self, cb, errMsg): fut = Future[int]() then_fut = fut.then(cb) fut.set_result(5) self.assertEqual(5, fut.wait()) with self.assertRaisesRegex(RuntimeError, errMsg): then_fut.wait() def test_then_wrong_arg(self): def wrong_arg(tensor): return tensor + 1 self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int") def test_then_no_arg(self): def no_arg(): return True self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given") def test_then_raise(self): def raise_value_error(fut): raise ValueError("Expected error") self._test_then_error(raise_value_error, "Expected error") def test_add_done_callback_simple(self): callback_result = False def callback(fut): nonlocal callback_result fut.wait() callback_result = True fut = Future[torch.Tensor]() fut.add_done_callback(callback) self.assertFalse(callback_result) fut.set_result(torch.ones(2, 2)) self.assertEqual(fut.wait(), torch.ones(2, 2)) self.assertTrue(callback_result) def test_add_done_callback_maintains_callback_order(self): callback_result = 0 def callback_set1(fut): nonlocal callback_result fut.wait() callback_result = 1 def callback_set2(fut): nonlocal callback_result fut.wait() callback_result = 2 fut = Future[torch.Tensor]() fut.add_done_callback(callback_set1) fut.add_done_callback(callback_set2) fut.set_result(torch.ones(2, 2)) self.assertEqual(fut.wait(), torch.ones(2, 2)) # set2 called last, callback_result = 2 self.assertEqual(callback_result, 2) def _test_add_done_callback_error_ignored(self, cb): fut = Future[int]() fut.add_done_callback(cb) fut.set_result(5) # error msg logged to stdout self.assertEqual(5, fut.wait()) def test_add_done_callback_error_is_ignored(self): def raise_value_error(fut): raise ValueError("Expected error") self._test_add_done_callback_error_ignored(raise_value_error) def test_add_done_callback_no_arg_error_is_ignored(self): def no_arg(): return True # Adding another level of function indirection here on purpose. # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI self._test_add_done_callback_error_ignored(no_arg) def test_interleaving_then_and_add_done_callback_maintains_callback_order(self): callback_result = 0 def callback_set1(fut): nonlocal callback_result fut.wait() callback_result = 1 def callback_set2(fut): nonlocal callback_result fut.wait() callback_result = 2 def callback_then(fut): nonlocal callback_result return fut.wait() + callback_result fut = Future[torch.Tensor]() fut.add_done_callback(callback_set1) then_fut = fut.then(callback_then) fut.add_done_callback(callback_set2) fut.set_result(torch.ones(2, 2)) self.assertEqual(fut.wait(), torch.ones(2, 2)) # then_fut's callback is called with callback_result = 1 self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1) # set2 called last, callback_result = 2 self.assertEqual(callback_result, 2) def test_interleaving_then_and_add_done_callback_propagates_error(self): def raise_value_error(fut): raise ValueError("Expected error") fut = Future[torch.Tensor]() then_fut = fut.then(raise_value_error) fut.add_done_callback(raise_value_error) fut.set_result(torch.ones(2, 2)) # error from add_done_callback's callback is swallowed # error from then's callback is not self.assertEqual(fut.wait(), torch.ones(2, 2)) with self.assertRaisesRegex(RuntimeError, "Expected error"): then_fut.wait() def test_collect_all(self): fut1 = Future[int]() fut2 = Future[int]() fut_all = torch.futures.collect_all([fut1, fut2]) def slow_in_thread(fut, value): time.sleep(0.1) fut.set_result(value) t = threading.Thread(target=slow_in_thread, args=(fut1, 1)) fut2.set_result(2) t.start() res = fut_all.wait() self.assertEqual(res[0].wait(), 1) self.assertEqual(res[1].wait(), 2) t.join() @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows") def test_wait_all(self): fut1 = Future[int]() fut2 = Future[int]() # No error version fut1.set_result(1) fut2.set_result(2) res = torch.futures.wait_all([fut1, fut2]) print(res) self.assertEqual(res, [1, 2]) # Version with an exception def raise_in_fut(fut): raise ValueError("Expected error") fut3 = fut1.then(raise_in_fut) with self.assertRaisesRegex(RuntimeError, "Expected error"): torch.futures.wait_all([fut3, fut2]) if __name__ == '__main__': run_tests()