{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Task grid viewer\n",
    "\n",
    "Loads task files (JSON) from `data/` and plots the input and output grids of each example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import ListedColormap, Normalize\n",
    "\n",
    "\n",
    "def _freeze_examples(examples):\n",
    "    \"\"\"Return the examples with each grid converted to a tuple of tuples.\"\"\"\n",
    "    return [{\n",
    "        'input': tuple(tuple(row) for row in example['input']),\n",
    "        'output': tuple(tuple(row) for row in example['output']),\n",
    "    } for example in examples]\n",
    "\n",
    "\n",
    "def get_data(is_train, data_dir='data'):\n",
    "    \"\"\"Load every task JSON file from one sub-folder of `data_dir`.\n",
    "\n",
    "    Args:\n",
    "        is_train: if true, read `<data_dir>/training`; otherwise read\n",
    "            `<data_dir>/devaluation`.\n",
    "        data_dir: directory that contains the sub-folders.\n",
    "\n",
    "    Returns:\n",
    "        A dict with the keys 'train' and 'test'. Each maps a task name\n",
    "        (file name without its '.json' extension) to a list of examples\n",
    "        of the form {'input': grid, 'output': grid}, where every grid is\n",
    "        a tuple of tuples.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if the selected folder does not exist.\n",
    "    \"\"\"\n",
    "    # NOTE(review): 'devaluation' looks like a typo for 'evaluation' --\n",
    "    # confirm against the real folder name before changing it.\n",
    "    path = os.path.join(data_dir, 'training' if is_train else 'devaluation')\n",
    "    data = {}\n",
    "    # Sorted so the task order does not depend on the file system's\n",
    "    # listing order (the plotting cell below takes the first task).\n",
    "    for filename in sorted(os.listdir(path)):\n",
    "        # splitext removes only the extension; str.rstrip('.json') would\n",
    "        # strip any trailing '.', 'j', 's', 'o', 'n' characters from the name.\n",
    "        name, ext = os.path.splitext(filename)\n",
    "        if ext != '.json':\n",
    "            continue  # skip stray entries that are not task files\n",
    "        with open(os.path.join(path, filename), encoding='utf-8') as f:\n",
    "            data[name] = json.load(f)\n",
    "    return {\n",
    "        split: {name: _freeze_examples(task[split]) for name, task in data.items()}\n",
    "        for split in ('train', 'test')\n",
    "    }\n",
    "\n",
    "\n",
    "def plot_task(task, solver=None):\n",
    "    \"\"\"Plot each example of a task as one column of grids.\n",
    "\n",
    "    The top row shows the inputs and the bottom row the expected outputs.\n",
    "    When `solver` is given, a middle row shows `solver(example['input'])`.\n",
    "\n",
    "    Args:\n",
    "        task: non-empty sequence of {'input': grid, 'output': grid} examples.\n",
    "        solver: optional callable that maps an input grid to a grid to plot.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `task` is empty.\n",
    "    \"\"\"\n",
    "    n = len(task)\n",
    "    if n == 0:\n",
    "        raise ValueError('task must contain at least one example')\n",
    "    # One colour per cell value 0..9.\n",
    "    cols = [\n",
    "        '#000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',\n",
    "        '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'\n",
    "    ]\n",
    "    args = {'cmap': ListedColormap(cols), 'norm': Normalize(vmin=0, vmax=9)}\n",
    "    m = 2 if solver is None else 3\n",
    "    # squeeze=False keeps `ax` two-dimensional even for a single example,\n",
    "    # so no special case is needed when n == 1.\n",
    "    fig, ax = plt.subplots(m, n, figsize=(n * 4, m * 4), squeeze=False)\n",
    "    for i, example in enumerate(task):\n",
    "        grids = [example['input']]\n",
    "        if solver is not None:\n",
    "            grids.append(solver(example['input']))\n",
    "        grids.append(example['output'])\n",
    "        for row, grid in enumerate(grids):\n",
    "            ax[row, i].imshow(grid, **args)\n",
    "            ax[row, i].axis('off')\n",
    "    fig.set_facecolor('#1E1E1E')\n",
    "    fig.subplots_adjust(wspace=0.1, hspace=0.1)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = get_data(is_train=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_task(next(iter(data['train'].values())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}