adm18/IMPAX/Untitled.ipynb

192 lines
693 KiB
Text
Raw Normal View History

2025-09-16 05:20:19 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 23] loss: 20962.588\n",
"[2, 23] loss: 5299.285\n",
"[3, 23] loss: 3746.179\n",
"[4, 23] loss: 3133.387\n",
"[5, 23] loss: 3026.136\n",
"[6, 23] loss: 2714.172\n",
"[7, 23] loss: 2451.235\n",
"[8, 23] loss: 2314.406\n",
"[9, 23] loss: 2621.939\n",
"[10, 23] loss: 2520.008\n",
"[11, 23] loss: 2269.186\n",
"[12, 23] loss: 2110.618\n",
"[13, 23] loss: 2268.559\n",
"[14, 23] loss: 2196.927\n",
"[15, 23] loss: 2362.194\n",
"[16, 23] loss: 2588.966\n",
"[17, 23] loss: 2107.856\n",
"[18, 23] loss: 2017.883\n",
"[19, 23] loss: 2223.861\n",
"[20, 23] loss: 1975.168\n",
"Finished Training\n"
]
}
],
"source": [
"import random\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import torch.optim as optim\n",
"\n",
"from dataset import *\n",
"from models import *\n",
"\n",
"TEST_STEP = 4\n",
"\n",
"trainset = IMPAXDataset('/shares/Public/IMPAX/train')\n",
"testset = IMPAXDataset('/shares/Public/IMPAX/train')\n",
"\n",
"# print(len(trainset))\n",
"# exit()\n",
"\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,\n",
" shuffle=True, num_workers=2)\n",
"\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=TEST_STEP,\n",
" shuffle=True, num_workers=2)\n",
"\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"net = N90().to(device)\n",
"\n",
"# criterion = nn.MSELoss(reduction='sum')\n",
"criterion = nn.MSELoss()\n",
"# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
"# optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)\n",
"\n",
"optimizer = optim.Adam(net.parameters(), lr=0.01)\n",
"\n",
"\n",
"# for epoch in range(3): # 训练所有!整套!数据 3 次\n",
"# for step, (batch_x, batch_y) in enumerate(trainloader): # 每一步 loader 释放一小批数据用来学习\n",
"# # 假设这里就是你训练的地方...\n",
"\n",
"# # 打出来一些数据\n",
"# print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',\n",
"# batch_x.numpy(), '| batch y: ', batch_y.numpy())\n",
"# exit()\n",
"\n",
"\n",
"for epoch in range(20): # loop over the dataset multiple times\n",
"\n",
" running_loss = 0.0\n",
" for i, data in enumerate(trainloader, 0):\n",
" # get the inputs; data is a list of [inputs, labels]\n",
" inputs, labels = data[0].to(device), data[1].to(device)\n",
"\n",
" # print(inputs[0])\n",
" # print(labels[0])\n",
" # exit()\n",
" # print(inputs)\n",
" # break\n",
" # # continue\n",
"\n",
" # zero the parameter gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # forward + backward + optimize\n",
" outputs = net(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # print statistics\n",
" running_loss += loss.item()\n",
"\n",
" if i % 2000 == 1999: # print every 2000 mini-batches\n",
" print('[%d, %5d] loss: %.3f' %\n",
" (epoch + 1, i + 1, running_loss / 2000))\n",
" running_loss = 0.0\n",
"\n",
" print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss))\n",
"\n",
"print('Finished Training')\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABIYAAAZnCAYAAAAWaHwOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9eZAd1XUG/vXbt3kzEkJiMXacVCqrWARaEKAFJMQisRiwkTBLjJ3grQAn5TgxLsex40qc2MQ2dmzAFhgQm20hsSMhoX0FISin8vs59XPFC5uk0Sxv3/r3x/N35vSd7n5vRgIL635VU29ev+6739vnfPeccx3XdWFhYWFhYWFhYWFhYWFhYWFhcfQh8rsugIWFhYWFhYWFhYWFhYWFhYXF7waWGLKwsLCwsLCwsLCwsLCwsLA4SmGJIQsLCwsLCwsLCwsLCwsLC4ujFJYYsrCwsLCwsLCwsLCwsLCwsDhKYYkhCwsLCwsLCwsLCwsLCwsLi6MUlhiysLCwsLCwsLCwsLCwsLCwOErxthFDjuNc4DjO/+M4zv86jvO5tysfCwsLC4t3H+w7wsLCwsIiDPY9YWFhYfHOwXFd9/An6jhRAP8vgIUAfg1gF4Clruv+92HPzMLCwsLiXQX7jrCwsLCwCIN9T1hYWFi8s3i7LIZmAPhf13X/P9d1awAeAnDp25SXhYWFhcW7C/YdYWFhYWERBvuesLCwsHgHEXub0j0RwK/U918DmBl0s+M4ruM4h7UAruvCTLPba4czz7HmZVpwHe52CcvvUPIy63Yo7coydfv8WOvAsrmui0ikzY22Wi1EIhG0Wi3PvbxPf3ccR+57u/vnd4mx9uF47ifG09djee5IgTm2u10TXNd9d1W0M8b0jgDa74m3tUQWFhYW72LY9wQQiURcynUWFhYWFiNotVpotVqh74m3ixjqCMdx/hrAX/N7IpHQv/kqS7xGZd1PseQ113URi8XgOA4ajYZc0y8MkgEdyun5ZDpmGV3XRavVQjQa9S1jq9UapdTyd7968ZofWREEP7fAIKUzrA3NegYRavrTcZxRbWum22w2R7W3zj8SiUg7ApC2HEu9/GCOjVarhWQyCcdxUKvVAADJZBLlchnxeFzKWC6XkUgkPGWORqOeurVarVHjwFT+SSTpPvUb467rotlsjipzEBnCe8Pqqv9nPSKRiPzxGT1PWC+zzLyu89L9xfml66bT7FSvbhC2Lpjfg8aI2TbmOAtKr5vyOo4j/WKuE2wrPVfCiC7+VqlUOub7+wrzPWFhYWFhYaGh3xORSAS9vb1+98j/WqY1/9f3mvJxmAwwls2tbjZRu7kvrHxBsvvhxFg3cf2eC9LjguQuv34IuuZXxvGWWZdd52fqb/p/nYepKwBtfcKU44Pq7icbBl0PategfMw0zeeC0g2CX/276edu7/frw7HoE2MZd+PdHO/0vF/du9EH/NaETmNc5zUwMNCxHm8XMfQbACep7+/57TWB67p3ArgTaDP8b1M5LCwsLCyOPHR8RwDe94S1GLKwsLA4qjDm90QsFrPvCQsLC4tx4u0ihnYB+GPHcd6P9iJ+NYBl3TyoLRkIP8sa/Zu+xu/RaFTS0RYBtGgA4LF0MNMzofMMskiIRqO+9wX9ry0LeJ1WM9oiZzyWFZ3gxz76WV35WbaYlkF+uxIm+2mmRSsafb/ZrubneNuBz8Vi7eEejUZRKpU81yqVilgD0UIjk8mgXq+L1Us0GkWlUsGECRPknjBrFf1/s9kcda9f/cMs2Mw2TCaToyzROF6YNq3i9HNBFkOmNZRZH79rQRZN3VjxjIftD7K0CcrXzCfsmU73dZOmHvd+FkO6DwB/S8K3c4fvCMK43xEWFhYWFkcFDst74lAsC/Qz3VimdMq/W1kiLA19zbTeD/r97USQNVOn/M3fwnSx8VglhT3TTZ92k7Zf/2u5m/pDLBZDo9EQfZRW+alUCtVqNbDMfrpUWH061SGszYPq2cmC6XChm/qb94+nLGEWR4erjmN51m8sdVonwqzGDlVnfluIIdd1G47jfArAswCiAH7ouu7PQu73raQmbsx7zGsmyaNjwOhBYCrI9Xrdt0xmw3aKPUNos8BOndLNgO/WjUyXOeh7p7z9JoZ5TZMJ+tPPTNJ8GYa9nM2+8SuX33NB8LuPbakXat3O8Xgc5XIZqVRK7nMcR8ijRqOBY445Bvv27ZM0E4lEYLmDyuK6bZcxPVYikQii0ajvvUF1JhHZbDZD24j36TFt5s/fea3RaPgSlM1mU9rDr55+RFUnBAk03VzrBp1efOYcCxIournmt075PWNe1+tVGPn1+4KxviMsLCwsLI4uHI73RJCSFPae9nv3dqOsdUI3G89jlZ+6qd87gW6IID+MlbTz0xW7lSFNpTmovJ1IPv07w33ocCjcCG42m6JPAG1ZnPeVSiXROcLK1KlNxlp+874gUkGnY7Z5mF4ynnKMJQ2/svB/v/vDrnfTtmHkcCd9tRM6kXndklSHe+6/LcfVj7kQjuNyssRisVGxZjgIqZBqMiJoUJhMJ+/RFhlUdE2ySIN5d3oB8HdtBdSp0x3HazVDhlnXq5u8g9IP+42LWZCFiIY5OM2Fw4/YCyqLzs9cXMmk+8XZMZ8xnw0rs3lfs9lEPB5HIpHA8PCw/JZIJFCr1RCPxwFArIWy2SyANllSqVSQTCaRyWQAAIVCwVP2oHqbbdJsNj3xd6LRKGKx2CiyJigtfrIPwxYqs2x+pKPruh6LoUaj4YnBxTnJtjPb0y89fvrNw25f5J1+M+dH0O5GGLoVZvzq4QczrlLY/ZyHYfOnXC6j2WweFWZEYXCsK5mFhYVFINzfv+DTY0YsFnMZY6ibd7ufInyoO+6d8ghTLrslWMKUSP27HxlyuOBXhyDF3Q9+7RzUHn75hiFIoe+GTPIjAYLKR90hHo+LoUE0GkWxWEQ0GkU+nwfQluN6enpQLpcBtPWLarXq2aCm7NjNmDxcY7SbPtBl6ZRvt2P3cNwTdu9YxnwQ2Tge4tccZ0F5dSLkzDp1qzcFzRdeGxgYQKPRCK3Y7yz4tEYQkaAVYPPTbFxzsHKymQ1pki581nSr8SMgTPgNRDNIbxjzG+RG0i1L2Kk8h+M5kgrm/bp9zEHsV8+gfHX/mUy02bd+z42HKSWTf/DgQezZswdAe3GfNm0aMpmMuInt2LED8XgcxWIRALBo0SK4rotcLof9+/cDaLtz+aXvN6HDSErt4sjvnQKjm5ZCQfU32595aRdL8zk/y7CwF4G+R6dnpqOfCSp7GDkZJoiY46fbk+NoUdVp7IT1qd889/vNb33wI1q7KbeFhcU7g06Ki4adtxYWRw6ClEbznd3NvB6rjB2msAblF6YL+ZV5rDpKt/CTYzopu355hJEcJrRc1c29fvkHlSuorEFKud//+jndj9FoFLVaDWvWrAHQ1gumTp2KyZMnY3BwEACwbt06DA0NSfDfW2+9VYgkP0LIr55hCBu/neo5HnQaF93Ol/Hocp363m/8hNU/jBwbazt2085mOmPp47GuQePhBeyZjhYWFhYWFhYWFhYWFhYWFhZHKY4IiyGTBfezcOjm2HZzZ97vftPaxXRbM1lCHVw5aGefz+l7/Ophpu3HkJrBnFm+bhC2O9Dt7oef5YNpzWK2F9My3eIIv/gpZhl123Sqr8kGdxoXTJuWQNlsFq7rIp/Py7P1el0CT9OS5swzz8SmTZswa9YsAEA8Hkc8Hsfg4KD4DYdZp5hui7o9zSDjQe5EfmNOjx+2Ka81m83AgNQE3dbMNjfz9+sr3faMK2Vax4XtGHWDbnbAguak/mT5Ou2M+I25Tjtt5q5W2Jw209HXzNhcftZWFhYWRy66WWMsLCzeOZiW6FomYixHv/mqAwNr2UYf4NFJD/B7v5vW4n6WP2ZsQZ2/KecGWVlojMWKJgx+OkG39wMYVXdgtH6gr/sd1BKkrwTpFlo+DLM86taKJkyu4yEwiUQCxxxzDIB2/KBoNIpqtSq6wpIlS/DQQw/h4x//uDzL+ppl96unX9uZz3SysgpKn2B
"text/plain": [
"<Figure size 1440x2160 with 12 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.rcParams['figure.figsize'] = [20, 30]\n",
"\n",
"# dataiter = iter(testloader)\n",
"dataiter = iter(trainloader)\n",
"images, labels = dataiter.next()\n",
"output = net(images.to(device))\n",
"\n",
"for j in range(TEST_STEP):\n",
" plt.subplot(TEST_STEP,3,j*3+1)\n",
" plt.imshow(images[j][0,:,:], cmap='gray')\n",
" plt.subplot(TEST_STEP,3,j*3+2)\n",
" plt.imshow(labels[j][0,:,:], cmap='gray')\n",
" plt.subplot(TEST_STEP,3,j*3+3)\n",
" out = output[j]\n",
"# print(out)\n",
" plt.imshow(out[0,:,:].cpu().detach().numpy(), cmap='gray')\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}