07: Implement Steepest Descent, abstract test function
This commit is contained in:
@ -70,7 +70,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"deletable": false,
|
||||
"nbgrader": {
|
||||
@ -252,7 +252,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"deletable": false,
|
||||
"nbgrader": {
|
||||
@ -271,9 +271,9 @@
|
||||
"source": [
|
||||
"def GS(A, b, eps, k_max = 10000):\n",
|
||||
" \"\"\"\n",
|
||||
" Return the estimate solution x to the problem Ax = b and the number\n",
|
||||
" of iterations k it took to decrease maximum norm error below eps\n",
|
||||
" or to exceed mi\n",
|
||||
" Return the Gauss-Seidel algorithm estimate solution x to the problem\n",
|
||||
" Ax = b and the number of iterations k it took to decrease maximum\n",
|
||||
" norm error below eps or to exceed iteration maximum k_max.\n",
|
||||
" \"\"\"\n",
|
||||
" \n",
|
||||
" # Assert n by n matrix.\n",
|
||||
@ -292,7 +292,7 @@
|
||||
" x_cur = np.dot(linalg.inv(D), b)\n",
|
||||
" \n",
|
||||
" k = 1\n",
|
||||
" while diff(x_cur, x_prev) > eps:\n",
|
||||
" while diff(x_cur, x_prev) > eps and k < k_max:\n",
|
||||
" k += 1\n",
|
||||
" # We will have to copy, as the array elements will point to the same\n",
|
||||
" # memory otherwise, and changes to one array will change the other aswell.\n",
|
||||
@ -348,7 +348,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"deletable": false,
|
||||
"nbgrader": {
|
||||
@ -404,7 +404,49 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 62,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# As the three algorithm functions will have the same signature,\n",
|
||||
"# it makes sense to only write the test function once.\n",
|
||||
"\n",
|
||||
"def test_alg(alg, alg_name):\n",
|
||||
" \"\"\"\n",
|
||||
" Check that function alg returns solutions for the example system Ax = b\n",
|
||||
" within the error defined by the same eps as used for the iteration.\n",
|
||||
" \"\"\"\n",
|
||||
" \n",
|
||||
" A = np.array([[ 10, - 1, 2, 0],\n",
|
||||
" [- 1, 11, - 1, 3],\n",
|
||||
" [ 2, - 1, 10, - 1],\n",
|
||||
" [ 0, 3, - 1, 8]])\n",
|
||||
" b = np.array( [ 6, 25, -11, 15] )\n",
|
||||
" x_exact = linalg.solve(A, b)\n",
|
||||
"\n",
|
||||
" print(\"Starting with A =\")\n",
|
||||
" print(A)\n",
|
||||
" print(\"and b =\", b)\n",
|
||||
" print(\"We apply the {} algorithm to solve Ax = b.\".format(alg_name))\n",
|
||||
" print()\n",
|
||||
"\n",
|
||||
" eps_list = [1e-1, 1e-2, 1e-3, 1e-4]\n",
|
||||
" for eps in eps_list:\n",
|
||||
" x, k = alg(A, b, eps)\n",
|
||||
" print(\"For eps = {:.0e}\\tafter k = {:d}\\t iterations:\".format(eps, k))\n",
|
||||
" print(\"x =\\t\\t\\t\", x)\n",
|
||||
" print(\"Ax =\\t\\t\\t\", np.dot(A, x))\n",
|
||||
" print(\"diff(Ax, b) =\\t\\t\", diff(A @ x, b))\n",
|
||||
" print(\"diff(x, x_exact) =\\t\", diff(x, x_exact))\n",
|
||||
" print()\n",
|
||||
" \n",
|
||||
" assert diff(x, x_exact) < eps\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"deletable": false,
|
||||
"nbgrader": {
|
||||
@ -433,24 +475,28 @@
|
||||
"We apply the Gauss-Seidel algorithm to solve Ax = b.\n",
|
||||
"\n",
|
||||
"For eps = 1e-01\tafter k = 4\t iterations:\n",
|
||||
"x =\t\t [ 0.99463393 1.99776509 -0.99803257 1.00108402]\n",
|
||||
"Ax =\t\t [ 5.95250909 24.98206671 -10.98990699 15. ]\n",
|
||||
"diff(Ax, b) =\t 0.04749090616931895\n",
|
||||
"x =\t\t\t [ 0.99463393 1.99776509 -0.99803257 1.00108402]\n",
|
||||
"Ax =\t\t\t [ 5.95250909 24.98206671 -10.98990699 15. ]\n",
|
||||
"diff(Ax, b) =\t\t 0.04749090616931895\n",
|
||||
"diff(x, x_exact) =\t 0.005366066491359844\n",
|
||||
"\n",
|
||||
"For eps = 1e-02\tafter k = 5\t iterations:\n",
|
||||
"x =\t\t [ 0.99938302 1.99982713 -0.99978549 1.00009164]\n",
|
||||
"Ax =\t\t [ 5.99443213 24.99877578 -10.99900762 15. ]\n",
|
||||
"diff(Ax, b) =\t 0.005567865937722516\n",
|
||||
"x =\t\t\t [ 0.99938302 1.99982713 -0.99978549 1.00009164]\n",
|
||||
"Ax =\t\t\t [ 5.99443213 24.99877578 -10.99900762 15. ]\n",
|
||||
"diff(Ax, b) =\t\t 0.005567865937722516\n",
|
||||
"diff(x, x_exact) =\t 0.000616975874427883\n",
|
||||
"\n",
|
||||
"For eps = 1e-03\tafter k = 6\t iterations:\n",
|
||||
"x =\t\t [ 0.99993981 1.99998904 -0.99997989 1.00000662]\n",
|
||||
"Ax =\t\t [ 5.99944928 24.99993935 -10.99991498 15. ]\n",
|
||||
"diff(Ax, b) =\t 0.000550717702960668\n",
|
||||
"x =\t\t\t [ 0.99993981 1.99998904 -0.99997989 1.00000662]\n",
|
||||
"Ax =\t\t\t [ 5.99944928 24.99993935 -10.99991498 15. ]\n",
|
||||
"diff(Ax, b) =\t\t 0.000550717702960668\n",
|
||||
"diff(x, x_exact) =\t 6.018928065554263e-05\n",
|
||||
"\n",
|
||||
"For eps = 1e-04\tafter k = 7\t iterations:\n",
|
||||
"x =\t\t [ 0.99999488 1.99999956 -0.99999836 1.00000037]\n",
|
||||
"Ax =\t\t [ 5.99995255 24.99999971 -10.99999375 15. ]\n",
|
||||
"diff(Ax, b) =\t 4.744782363452771e-05\n",
|
||||
"x =\t\t\t [ 0.99999488 1.99999956 -0.99999836 1.00000037]\n",
|
||||
"Ax =\t\t\t [ 5.99995255 24.99999971 -10.99999375 15. ]\n",
|
||||
"diff(Ax, b) =\t\t 4.744782363452771e-05\n",
|
||||
"diff(x, x_exact) =\t 5.11751035947583e-06\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
@ -462,29 +508,7 @@
|
||||
" within the error defined by the same eps as used for the iteration.\n",
|
||||
" \"\"\"\n",
|
||||
" \n",
|
||||
" A = np.array([[ 10, - 1, 2, 0],\n",
|
||||
" [- 1, 11, - 1, 3],\n",
|
||||
" [ 2, - 1, 10, - 1],\n",
|
||||
" [ 0, 3, - 1, 8]])\n",
|
||||
" b = np.array( [ 6, 25, -11, 15] )\n",
|
||||
" x_exact = linalg.solve(A, b)\n",
|
||||
"\n",
|
||||
" print(\"Starting with A =\")\n",
|
||||
" print(A)\n",
|
||||
" print(\"and b =\", b)\n",
|
||||
" print(\"We apply the Gauss-Seidel algorithm to solve Ax = b.\")\n",
|
||||
" print()\n",
|
||||
"\n",
|
||||
" eps_list = [1e-1, 1e-2, 1e-3, 1e-4]\n",
|
||||
" for eps in eps_list:\n",
|
||||
" x, k = GS(A, b, eps)\n",
|
||||
" print(\"For eps = {:.0e}\\tafter k = {:d}\\t iterations:\".format(eps, k))\n",
|
||||
" print(\"x =\\t\\t\", x)\n",
|
||||
" print(\"Ax =\\t\\t\", np.dot(A, x))\n",
|
||||
" print(\"diff(Ax, b) =\\t\", diff(A @ x, b))\n",
|
||||
" print()\n",
|
||||
" \n",
|
||||
" assert diff(x, x_exact) < eps\n",
|
||||
" return test_alg(GS, \"Gauss-Seidel\")\n",
|
||||
" \n",
|
||||
"test_GS()"
|
||||
]
|
||||
@ -522,7 +546,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"deletable": false,
|
||||
"nbgrader": {
|
||||
@ -539,14 +563,35 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def SD(A, b, eps):\n",
|
||||
" # YOUR CODE HERE\n",
|
||||
" raise NotImplementedError()"
|
||||
"def SD(A, b, eps, k_max = 10000):\n",
|
||||
" \"\"\"\n",
|
||||
" Return the Steepest Descent algorithm estimate solution x to the problem\n",
|
||||
" Ax = b and the number of iterations k it took to decrease maximum\n",
|
||||
" norm error below eps or to exceed iteration maximum k_max.\n",
|
||||
" \"\"\"\n",
|
||||
" \n",
|
||||
" # Assert n by n matrix.\n",
|
||||
" assert len(A.shape) == 2 and A.shape[0] == A.shape[1]\n",
|
||||
" \n",
|
||||
" n = len(A)\n",
|
||||
" x_cur = np.zeros(n)\n",
|
||||
" x_prev = np.zeros(n)\n",
|
||||
" \n",
|
||||
" k = 0\n",
|
||||
" while diff(x_cur, x_prev) > eps and k < k_max or k == 0:\n",
|
||||
" k += 1\n",
|
||||
" x_prev = x_cur.copy()\n",
|
||||
" \n",
|
||||
" v = b - A @ x_prev\n",
|
||||
" t = np.dot(v,v)/np.dot(v, A @ v)\n",
|
||||
" x_cur = x_prev.copy() + t*v\n",
|
||||
" \n",
|
||||
" return x_cur, k"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"deletable": false,
|
||||
"nbgrader": {
|
||||
@ -561,11 +606,54 @@
|
||||
"task": false
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Starting with A =\n",
|
||||
"[[10 -1 2 0]\n",
|
||||
" [-1 11 -1 3]\n",
|
||||
" [ 2 -1 10 -1]\n",
|
||||
" [ 0 3 -1 8]]\n",
|
||||
"and b = [ 6 25 -11 15]\n",
|
||||
"We apply the Steepest Descent algorithm to solve Ax = b.\n",
|
||||
"\n",
|
||||
"For eps = 1e-01\tafter k = 4\t iterations:\n",
|
||||
"x =\t\t\t [ 0.99748613 1.98300329 -0.98904751 1.01283183]\n",
|
||||
"Ax =\t\t\t [ 6.01376302 24.84309304 -10.89133793 15.04071202]\n",
|
||||
"diff(Ax, b) =\t\t 0.15690696195356324\n",
|
||||
"diff(x, x_exact) =\t 0.016996711694861055\n",
|
||||
"\n",
|
||||
"For eps = 1e-02\tafter k = 6\t iterations:\n",
|
||||
"x =\t\t\t [ 0.99983175 1.99716093 -0.99850509 1.00217552]\n",
|
||||
"Ax =\t\t\t [ 6.00414638 24.97397012 -10.98472385 15.00739201]\n",
|
||||
"diff(Ax, b) =\t\t 0.02602988052923294\n",
|
||||
"diff(x, x_exact) =\t 0.002839069878101119\n",
|
||||
"\n",
|
||||
"For eps = 1e-03\tafter k = 9\t iterations:\n",
|
||||
"x =\t\t\t [ 0.99991029 1.99994247 -0.99999645 1.00023877]\n",
|
||||
"Ax =\t\t\t [ 5.99916754 25.00016961 -11.00032515 15.00173404]\n",
|
||||
"diff(Ax, b) =\t\t 0.0017340408511579142\n",
|
||||
"diff(x, x_exact) =\t 0.00023877424067331177\n",
|
||||
"\n",
|
||||
"For eps = 1e-04\tafter k = 11\t iterations:\n",
|
||||
"x =\t\t\t [ 0.99998551 1.99999065 -0.99999949 1.00003874]\n",
|
||||
"Ax =\t\t\t [ 5.9998655 25.00002733 -11.00005329 15.00028137]\n",
|
||||
"diff(Ax, b) =\t\t 0.0002813662634846281\n",
|
||||
"diff(x, x_exact) =\t 3.874139053650083e-05\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def test_SD():\n",
|
||||
" # YOUR CODE HERE\n",
|
||||
" raise NotImplementedError()\n",
|
||||
" \"\"\"\n",
|
||||
" Check that SD returns solutions for the example system Ax = b\n",
|
||||
" within the error defined by the same eps as used for the iteration.\n",
|
||||
" \"\"\"\n",
|
||||
" \n",
|
||||
" return test_alg(SD, \"Steepest Descent\")\n",
|
||||
" \n",
|
||||
"test_SD()"
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user