{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5b0d7e2a-9c41-4f6e-8a13-2d4c6f8e0a11",
   "metadata": {},
   "source": [
    "# ARC-style task loading and plotting helpers\n",
    "\n",
    "Helpers for loading tasks stored one JSON file per task under `data/training` or `data/evaluation` (each file holds `train` and `test` lists of `input`/`output` grids), and for drawing those grids with a fixed 10-colour palette for grid values 0-9."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3e9a1b4-6d2f-4b8a-9e07-1f5a3c7d9b22",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "PySwip (a Python bridge to SWI-Prolog, which must be installed separately) is installed here; nothing in this notebook imports it yet. `%pip` installs into the environment of the running kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f2b6d4c-1a3e-4c5f-b7d9-0e2a4c6f8b33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): pin a version (PySwip==X.Y.Z) once the one in use is confirmed.\n",
    "%pip install PySwip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37e548c4-263c-4022-abbc-40ec7ef71911",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from matplotlib import colors\n",
    "from matplotlib.colors import ListedColormap, Normalize\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4a7c9e1-3b5f-4d2a-8c6e-9f1b3d5a7c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder containing the `training` and `evaluation` task folders.\n",
    "DATA_DIR = Path('./data')\n",
    "\n",
    "# PALETTE[v] is the colour drawn for grid value v (0-9).\n",
    "PALETTE = [\n",
    "    '#000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',\n",
    "    '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25',\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6b1d3f5-7a9c-4e2b-8d40-3a5c7e9b1d55",
   "metadata": {},
   "source": [
    "## Loading tasks\n",
    "\n",
    "Grids are converted to tuples of tuples so they are hashable and cannot be mutated by accident."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7c2e4a6-8b0d-4f3c-9e51-4b6d8f0c2e66",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _grid_examples(examples):\n",
    "    \"\"\"Return `examples` with each 'input'/'output' grid as a tuple of tuples.\"\"\"\n",
    "    return [\n",
    "        {\n",
    "            'input': tuple(tuple(row) for row in example['input']),\n",
    "            'output': tuple(tuple(row) for row in example['output']),\n",
    "        }\n",
    "        for example in examples\n",
    "    ]\n",
    "\n",
    "\n",
    "def get_data(is_train, data_dir=DATA_DIR):\n",
    "    \"\"\"Load every task in the training or evaluation split.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    is_train : bool\n",
    "        Read `<data_dir>/training` if true, otherwise `<data_dir>/evaluation`.\n",
    "    data_dir : str or Path, optional\n",
    "        Folder holding both splits; defaults to `DATA_DIR`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        `{'train': {task_id: examples}, 'test': {task_id: examples}}`, where\n",
    "        `task_id` is the file name without `.json` and each example is\n",
    "        `{'input': grid, 'output': grid}` with grids as tuples of tuples.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    FileNotFoundError\n",
    "        If the split folder does not exist.\n",
    "    KeyError\n",
    "        If a task file lacks 'train'/'test', or an example lacks 'input'/'output'.\n",
    "    \"\"\"\n",
    "    split_dir = Path(data_dir) / ('training' if is_train else 'evaluation')\n",
    "    # glob() on a missing folder just yields nothing; fail loudly instead.\n",
    "    if not split_dir.is_dir():\n",
    "        raise FileNotFoundError(f'task folder not found: {split_dir}')\n",
    "    raw_tasks = {}\n",
    "    # Sorted for a reproducible task order; only *.json files are read, so stray\n",
    "    # files such as .DS_Store do not crash json.load.\n",
    "    for task_path in sorted(split_dir.glob('*.json')):\n",
    "        with task_path.open(encoding='utf-8') as f:\n",
    "            # `.stem` drops exactly the '.json' suffix, whereas str.rstrip('.json')\n",
    "            # would strip any trailing '.', 'j', 's', 'o', 'n' characters of the id.\n",
    "            raw_tasks[task_path.stem] = json.load(f)\n",
    "    return {\n",
    "        'train': {task_id: _grid_examples(task['train']) for task_id, task in raw_tasks.items()},\n",
    "        'test': {task_id: _grid_examples(task['test']) for task_id, task in raw_tasks.items()},\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8d3f5b7-9c1e-4a4d-8f62-5c7e9a1d3f77",
   "metadata": {},
   "source": [
    "## Plotting\n",
    "\n",
    "`plot_task` draws one column per example: input on top, the solver's prediction in the middle (if a solver is passed), expected output at the bottom."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9e4a6c8-0d2f-4b5e-9a73-6d8f0b2e4a88",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_task(task, solver=None):\n",
    "    \"\"\"Plot every example of a task, one column per example.\n",
    "\n",
    "    Rows, top to bottom: the input grid, `solver(input)` when a solver is\n",
    "    given, and the expected output grid.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    task : sequence of dict\n",
    "        Examples with 'input' and 'output' grids, e.g. one value of\n",
    "        `get_data(...)['train']`.\n",
    "    solver : callable, optional\n",
    "        Called on each input grid; its result is drawn in the middle row.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `task` has no examples.\n",
    "    \"\"\"\n",
    "    if len(task) == 0:\n",
    "        raise ValueError('plot_task needs at least one example')\n",
    "    image_args = {'cmap': ListedColormap(PALETTE), 'norm': Normalize(vmin=0, vmax=9)}\n",
    "    n_cols = len(task)\n",
    "    n_rows = 2 if solver is None else 3\n",
    "    # squeeze=False keeps `axes` 2-D even for a single example, so one loop\n",
    "    # handles every task size.\n",
    "    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4), squeeze=False)\n",
    "    for col, example in enumerate(task):\n",
    "        grids = [example['input']]\n",
    "        if solver is not None:\n",
    "            grids.append(solver(example['input']))\n",
    "        grids.append(example['output'])\n",
    "        for row, grid in enumerate(grids):\n",
    "            axes[row, col].imshow(grid, **image_args)\n",
    "            axes[row, col].axis('off')\n",
    "    fig.set_facecolor('#1E1E1E')\n",
    "    fig.subplots_adjust(wspace=0.1, hspace=0.1)\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}