Skip to content

Commit

Permalink
change to optimized subs
Browse files Browse the repository at this point in the history
  • Loading branch information
nqdu committed Sep 18, 2024
1 parent d68e4ad commit 8f579c6
Show file tree
Hide file tree
Showing 2 changed files with 962 additions and 904 deletions.
164 changes: 144 additions & 20 deletions code_gen/aniso_code_gen.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,79 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def replace_fortran_str(text):\n",
" #temp:str = text.replace(\"+ &\\n \",\" + \")\n",
" temp:str = text.replace(\" &\\n \",\"\")\n",
" temp = temp.replace(\"+\",\"+ \")\n",
" temp = temp.replace(\"-\",\"- \")\n",
" #temp:str = temp.replace(\"- &\\n \",\" - \")\n",
" \n",
"\n",
" # convert to list\n",
" a = list(temp)\n",
"\n",
" # find all operators\n",
" idx = [i for i in range(len(temp)) if temp.startswith(' + ', i)]\n",
" idx_m = [i for i in range(len(temp)) if temp.startswith(' - ', i)]\n",
" idx.extend(idx_m)\n",
" idx.sort()\n",
" idx.append(len(a) - 1)\n",
" idx1 = [False for i in range(len(idx))]\n",
" \n",
" # add change line symbol at +/-\n",
" for i in range(80,len(a),72):\n",
" # loc = \n",
" # loc = -1\n",
" # for j in range(len(idx) - 1):\n",
" # if i >= idx[j] and i < idx[j+1]:\n",
" # loc = j \n",
" # break\n",
" # if loc != -1 and idx1[loc] == False :\n",
" # s = a[idx[loc]]\n",
" # if s == ' ':\n",
" # a[idx[loc]] = ' &\\n '\n",
" # else:\n",
" # a[idx[loc]] = s + ' &\\n '\n",
" # idx1[loc] = True\n",
" # find closest\n",
" loc = -1; dist = 999999\n",
" for j in range(len(idx) - 1):\n",
" id = abs(i - idx[j])\n",
" if dist > id:\n",
" loc = j\n",
" dist = id\n",
" if loc != -1 and idx1[loc] == False :\n",
" s = a[idx[loc]]\n",
" if s == ' ':\n",
" a[idx[loc]] = ' &\\n '\n",
" else:\n",
" a[idx[loc]] = s + ' &\\n '\n",
" idx1[loc] = True\n",
" \n",
" # replace the last &\\n with \\n\n",
" if ' &\\n ' in a[-1]:\n",
" a[-1] = a[-1].replace(' &\\n ','')\n",
" \n",
" # now replace sin/cos function with corresponding arrays\n",
" f_temp = ''.join(a)\n",
" f_temp = f_temp.replace(\"cos(theta0)\",\"costh0\")\n",
" f_temp = f_temp.replace(\"sin(theta0)\",\"sinth0\")\n",
" f_temp = f_temp.replace(\"cos(dphi)\",\"cosphi\")\n",
" f_temp = f_temp.replace(\"sin(dphi)\",\"sinphi\")\n",
" f_temp = f_temp.replace(\"+ \",\"+ \")\n",
" f_temp = f_temp.replace(\"- \",\"- \")\n",
"\n",
"\n",
" return f_temp"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -148,16 +220,9 @@
" \"im\": \"aimag\",\n",
"}\n",
"\n",
"def replace_str_sin_cos(f_out:str):\n",
" f_temp = f_out.replace(\"cos(theta0)\",\"costh0\")\n",
" f_temp = f_temp.replace(\"sin(theta0)\",\"sinth0\")\n",
" f_temp = f_temp.replace(\"cos(dphi)\",\"cosphi\")\n",
" f_temp = f_temp.replace(\"sin(dphi)\",\"sinphi\")\n",
"\n",
" return f_temp\n",
"\n",
"# open a txt file to write all functions\n",
"fileid = open(\"src/tti_subs.f90\",\"w\")\n",
"fileid = open(\"../src//tti_subs.f90\",\"w\")\n",
"fileid.write(\"!===========================================================================\\n\")\n",
"fileid.write(\"!============================= AUTO CODE FROM SYMPY ========================\\n\")\n",
"fileid.write(\"!===========================================================================\\n\")\n",
Expand Down Expand Up @@ -209,7 +274,8 @@
" print(f\"k^{p}, {f[ii][jj]} {psi_lst[kk]} = {weak_expr}\")\n",
" #weak_expr.replace(\"cos(theta0)\",\"costh0\")\n",
" f_out = sp.fcode(weak_expr,standard=2008,source_format='free',user_functions=funclist,assign_to=\"temp(:)\")\n",
" f_out = replace_str_sin_cos(f_out)\n",
" #display(f_out)\n",
" f_out = replace_fortran_str(f_out)\n",
" \n",
" # cache temporary arrays\n",
" fileid.write(f\" ! k^{p}, {f[ii][jj]} {psi_lst[kk]}\\n\")\n",
Expand Down Expand Up @@ -272,7 +338,7 @@
" if param == k0:\n",
" param_out = 'wvnm'\n",
" f_out = sp.fcode(expr,standard=2008,source_format='free',user_functions=funclist,assign_to=f\"temp(:)\")\n",
" f_out = replace_str_sin_cos(f_out)\n",
" f_out = replace_fortran_str(f_out)\n",
" fileid.write(f\" {f_out}\\n\")\n",
" fileid.write(f\" K{param_out}(:) = real(temp(:),kind=dp)\\n\\n\")\n",
"\n",
Expand All @@ -286,7 +352,7 @@
" dLag_dkvec += -I * sp.conjugate(s[i]) * c[i,m,p,j] * (-I * k[j] * s[p] + e3[j] * ds[p]) \\\n",
" - (I * k[j] * sp.conjugate(s[i]) + e3[j] * sp.conjugate(ds[i])) * c[i,j,p,m] * (-I * s[p]) \n",
" f_out = sp.fcode(dLag_dkvec,standard=2008,source_format='free',user_functions=funclist,assign_to=f\"temp(:)\")\n",
" f_out = replace_str_sin_cos(f_out)\n",
" f_out = replace_fortran_str(f_out)\n",
" fileid.write(f\" {f_out}\\n\")\n",
" fileid.write(f\" dL_dkv(:,{m+1}) = real(temp(:),kind=dp)\\n\")\n",
"\n",
Expand All @@ -299,25 +365,23 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 2 U \\omega \\rho \\overline{U} + 2 V \\omega \\rho \\overline{V} + 2 W \\omega \\rho \\overline{W}$"
],
"text/plain": [
"2*U*omega*rho*conjugate(U) + 2*V*omega*rho*conjugate(V) + 2*W*omega*rho*conjugate(W)"
"''"
]
},
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sp.diff(Lag,om)"
"sp.diff(Lag,om)\n",
"\"\""
]
},
{
Expand Down Expand Up @@ -442,7 +506,67 @@
"metadata": {},
"outputs": [],
"source": [
"str(dU)"
"def replace_fortran_str1(text):\n",
" temp:str = text.replace(\" &\\n \",\"\")\n",
"\n",
" # find all sin\n",
" idx= [i for i in range(len(temp)) if temp.startswith('sin(', i)]\n",
" idx_cos = [i for i in range(len(temp)) if temp.startswith('cos(', i)] \n",
" idx.extend(idx_cos)\n",
" idx.sort()\n",
" idx1 = idx.copy()\n",
"\n",
" # get start/end with sin/cos\n",
" a = list(temp)\n",
" for i in range(len(idx)):\n",
" loc = 0\n",
" while a[idx[i] + loc] != ')':\n",
" loc += 1\n",
" idx1[i] = idx[i] + loc\n",
"\n",
" # add change line symbol\n",
" for i in range(80,len(a),72):\n",
" print(i,len(a))\n",
" # check if this location is in a sin/cos parenthesis\n",
" loc = -1\n",
" for j in range(len(idx)):\n",
" if i >= idx[j] and i <=idx1[j]:\n",
" loc = j\n",
" break\n",
" if loc == -1:\n",
" a[i] = a[i] + ' &\\n '\n",
" else:\n",
" a[idx1[loc]] = a[idx1[loc]] + ' &\\n '\n",
"\n",
" \n",
" # now replace sin/cos function with corresponding arrays\n",
" f_temp = ''.join(a)\n",
" f_temp = f_temp.replace(\"cos(theta0)\",\"costh0\")\n",
" f_temp = f_temp.replace(\"sin(theta0)\",\"sinth0\")\n",
" f_temp = f_temp.replace(\"cos(dphi)\",\"cosphi\")\n",
" f_temp = f_temp.replace(\"sin(dphi)\",\"sinphi\")\n",
"\n",
" return f_temp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = 'temp(:) = A*sin(dphi)*sin(theta0)**2*cos(dphi)*cos(theta0)**2 + C*sin( &\\n dphi)*sin(theta0)**2*cos(dphi)*cos(theta0)**2 - 2*F*sin(dphi)*sin &\\n (theta0)**2*cos(dphi)*cos(theta0)**2 - 4*L*sin(dphi)*sin(theta0) &\\n **2*cos(dphi)*cos(theta0)**2 + L*sin(dphi)*sin(theta0)**2*cos( &\\n dphi) - N*sin(dphi)*sin(theta0)**2*cos(dphi)'\n",
"temp = replace_fortran_str1(text)\n",
"print(temp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"temp"
]
},
{
Expand Down
Loading

0 comments on commit 8f579c6

Please sign in to comment.