Skip to content

Commit

Permalink
add examples and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
schmidtdominik committed Feb 11, 2024
1 parent 5ca1747 commit 2da55e9
Show file tree
Hide file tree
Showing 12 changed files with 299 additions and 391 deletions.
199 changes: 168 additions & 31 deletions Untitled-1.ipynb → example.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,150 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pip install -q git+https://github.com/cpgoodri/jax_transformations3d.git"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import jax\n",
"from jax import jit, vmap\n",
"from functools import partial\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import jax_transformations3d as jts\n",
"import colorsys\n",
"import numpy as np\n",
"import math\n",
"import time\n",
"from tqdm.auto import tqdm, trange\n",
"\n",
"from raytrace import *"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
"# bounce count, time\n",
"# 4, 31,\n",
"# 8, 61.3, 71\n",
"# 16, 121-122\n",
"\n",
"sphere_pos, sphere_radius, mat_color, em_color, em_strength, mat = stack_dict_list(\n",
" spheres\n",
")\n",
"\n",
"key = jax.random.PRNGKey(0)\n",
"result_img = jnp.zeros((res_y, res_x, 3))\n",
"\n",
"# this should be before every ray trace, but for now keep it here\n",
"x_offset, y_offset = (jax.random.uniform(subkey, (2,)) - 0.5) * 0.005\n",
"ray_pos, ray_dirs = get_init(\n",
" res_x, res_y, x_persp, y_persp, camera_persp, x_offset, y_offset\n",
")\n",
"#\n",
"\n",
"result_img.block_until_ready()\n",
"t0 = time.time()\n",
"\n",
"k = 10\n",
"for i in range(k):\n",
" key, subkey = jax.random.split(key, 2)\n",
" key_grid = jax.random.split(subkey, res_x * res_y).reshape((res_x, res_y, -1))\n",
"\n",
" result_img += full_ray_trace(\n",
" ray_pos,\n",
" ray_dirs,\n",
" key_grid,\n",
" sphere_pos,\n",
" sphere_radius,\n",
" mat_color,\n",
" em_color,\n",
" em_strength,\n",
" mat,\n",
" )\n",
"\n",
"result_img.block_until_ready()\n",
"print((time.time() - t0) * 1000)\n",
"\n",
"plt.imshow(result_img / jnp.quantile(result_img.flatten(), 0.95))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"base_res = 256\n",
"x_persp, y_persp = 1.5 * 3, 1 * 3\n",
"res_x, res_y = int(base_res * x_persp), int(base_res * y_persp)\n",
"camera_persp = 12\n",
"\n",
"\n",
"n = 5\n",
"spheres = [\n",
" {\n",
" \"pos\": [5, -5 + l * 10, 0],\n",
" \"radius\": 1,\n",
" \"mat_color\": colorsys.hls_to_rgb(l * (1 - 1 / n), 0.5, 1),\n",
" \"em_color\": colorsys.hls_to_rgb(l * (1 - 1 / n), 0.5, 1),\n",
" \"em_strength\": 0,\n",
" \"mat\": max(l, 0.01),\n",
" }\n",
" for l in jnp.linspace(0, 1, n)\n",
"] + [\n",
" # {'pos': [10000+15, 0, 0], 'radius': 10000, 'mat_color': [1, 1, 1], 'em_color': [1, 1, 1], 'em_strength': 0.01, 'mat': 0.5},\n",
" # ground\n",
" {\n",
" \"pos\": [5, 0, 40000],\n",
" \"radius\": 40000 - 1,\n",
" \"mat_color\": [1, 1, 1],\n",
" \"em_color\": [1, 1, 1],\n",
" \"em_strength\": 0,\n",
" \"mat\": 1,\n",
" },\n",
" # back wall\n",
" # {'pos': [40000, 0, 0], 'radius': 40000-100, 'mat_color': [0.1, .1, 0.1], 'em_color': [1, 1, 1], 'em_strength': 0, 'mat': 1},\n",
" # ceiling light\n",
" {\n",
" \"pos\": [5, 0, -40000],\n",
" \"radius\": 40000 - 5000,\n",
" \"mat_color\": [1, 1, 1],\n",
" \"em_color\": [1, 1, 1],\n",
" \"em_strength\": 0.000001,\n",
" \"mat\": 1,\n",
" },\n",
" # {'pos': [10, 5, -5], 'radius': 4, 'mat_color': [1, 1, 1], 'em_color': [1, 1, 1], 'em_strength': 0, 'mat': 0},\n",
" # {'pos': [5-math.sin(math.pi*i*2)*10, math.cos(math.pi*i*2)*10, -4], 'radius': 1, 'mat_color': [0, 0, 0], 'em_color': [1, 1, 1], 'em_strength': 3, 'mat': 1},\n",
" # {'pos': [6, -10, 0], 'radius': 0.5, 'mat_color': [0, 0, 0], 'em_color': [0, 0.5, 1], 'em_strength': 40, 'mat': 1},\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -34,7 +179,7 @@
" mat,\n",
" )\n",
"\n",
" return result_img\n"
" return result_img"
]
},
{
Expand Down Expand Up @@ -86,7 +231,7 @@
" # result_img = result_img.at[:400].set(0)\n",
" # plt.imshow(result_img/result_img[300:].mean(axis=-1).max()*16, interpolation='none')\n",
" # plt.show()\n",
" images.append(result_img)\n"
" images.append(result_img)"
]
},
{
Expand Down Expand Up @@ -135,7 +280,7 @@
"plt.figure(figsize=(15, 15))\n",
"# result_img = result_img.at[:400].set(0)\n",
"plt.imshow(result_img / result_img[300:].mean(axis=-1).max(), interpolation=\"none\")\n",
"plt.show()\n"
"plt.show()"
]
},
{
Expand All @@ -152,18 +297,11 @@
"\n",
"cpus = jax.devices(\"cpu\")\n",
"results = jnp.stack([jax.device_put(r, cpus[0]) for r in results])\n",
"\n",
"\n",
"\n",
"\n",
"images_procs = jnp.stack(images)\n",
"\n",
"\n",
"max_v = jnp.quantile(images_procs[:, 250:].flatten(), 0.95)\n",
"\n",
"images_procs = (images_procs / max_v * 255).astype(jnp.uint8)\n",
"\n",
"imageio.mimsave(\"video.mp4\", list(images_procs) * 2, fps=20)\n"
"imageio.mimsave(\"video.mp4\", list(images_procs) * 2, fps=20)"
]
},
{
Expand Down Expand Up @@ -208,7 +346,7 @@
" \"em_strength\": 40,\n",
" \"mat\": 1,\n",
" },\n",
"]\n"
"]"
]
},
{
Expand All @@ -228,30 +366,29 @@
" normal_result = jnp.take_along_axis(normal, indices=closest_hit[..., None, None], axis=-1).squeeze()\n",
"\n",
" return did_hit_result, dst_result, hit_point_result, normal_result\n",
"\"\"\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"t = time.time()\n",
"dist, hit_point, normal = ray_trace_iter(ray_pos, ray_dirs, sphere_pos, sphere_radius)\n",
"print(dist.sum())\n",
"print(time.time() - t)\n"
"\"\"\""
]
}
],
"metadata": {
"language_info": {
"name": "python"
"kernelspec": {
"display_name": "jax-tf",
"language": "python",
"name": "jax-tf"
},
"orig_nbformat": 4
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
131 changes: 131 additions & 0 deletions raytrace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit, vmap


def ray_intersect(ray_origin, ray_dir, sphere_center, sphere_radius):
# ray_dir should be normalized
offset_ray_origin = ray_origin - sphere_center
a = jnp.dot(ray_dir, ray_dir)
b = 2 * jnp.dot(offset_ray_origin, ray_dir)
c = jnp.dot(offset_ray_origin, offset_ray_origin) - sphere_radius**2

discriminant = b**2 - 4 * a * c
dist = (-b - jnp.sqrt(discriminant)) / (2 * a)

# sphere was hit if (discriminant >= 0) & (dist >= 0)
# if discriminant < 0, then dist is nan already, if dist < 0,
# meaning hitpoint is against ray direction, set dist to nan
dist = jnp.where(dist < 0, jnp.nan, dist)

return dist


def ray_intersect_target_batch(ray_origin, ray_dir, sphere_center, sphere_radius):
dist = ray_intersect(ray_origin, ray_dir, sphere_center, sphere_radius)
closest_hit = jnp.nanargmin(dist)
dist = dist[closest_hit]

return dist, closest_hit


@jit
@partial(vmap, in_axes=(1, 1, 1, None, None, None, None, None, None))
@partial(vmap, in_axes=(0, 0, 0, None, None, None, None, None, None))
def full_ray_trace(
ray_origin,
ray_dir,
key,
sphere_center,
sphere_radius,
mat_color,
em_color,
em_strength,
mat,
):
def ray_trace_single_hit(args):
inc_light, ray_color, ray_origin, ray_dir, key, i, done = args
dist, closest_hit = ray_intersect_target_batch(
ray_origin, ray_dir, sphere_center, sphere_radius
)
did_hit = ~jnp.isnan(dist) # or: closest_hit != -1
# done = done | (did_hit & (em_strength[closest_hit] == 1) & (jnp.arange(10)[i] == 0))
# did_hit = did_hit & ~done

hit_point = ray_origin + ray_dir * dist
normal = hit_point - sphere_center[closest_hit]
normal = normal / jnp.linalg.norm(normal)

emitted_light = em_color[closest_hit] * em_strength[closest_hit]
light_strength = jnp.dot(normal, -ray_dir)
light_strength = jnp.where(jnp.isnan(light_strength), 0, light_strength)
inc_light += did_hit * (emitted_light * ray_color)
# *light_strength
ray_color = (
did_hit * ray_color * mat_color[closest_hit] + (~did_hit) * ray_color
)

key, subkey = jax.random.split(key, 2)
random_dir = jax.random.normal(subkey, (3,))
random_dir = random_dir / jnp.linalg.norm(random_dir)
diffuse_reflect = random_dir * jnp.sign(jnp.dot(random_dir, normal))
diffuse_reflect = diffuse_reflect / jnp.linalg.norm(diffuse_reflect)

specular_reflect = (
ray_dir - 2 * jnp.dot(ray_dir, normal) * normal
) # maybe should be -raydir too?
specular_reflect = specular_reflect / jnp.linalg.norm(specular_reflect)

alpha = mat[closest_hit]
reflect_dir = alpha * diffuse_reflect + (1 - alpha) * specular_reflect

return inc_light, ray_color, hit_point, reflect_dir, key, i + 1, done | ~did_hit

inc_light = jnp.zeros((3,))
ray_color = jnp.ones((3,))

# inc_light, ray_color, _, _, _ = ray_trace_single_hit(inc_light, ray_color, ray_origin, ray_dir, key)

def cond_fun(args):
i = args[-2]
done = args[-1]
return (i < 8) | done

inc_light, ray_color, _, _, _, _, _ = jax.lax.while_loop(
cond_fun,
ray_trace_single_hit,
(inc_light, ray_color, ray_origin, ray_dir, key, 0, False),
)
return inc_light


# coordinates are (x, y, z)
# x is forward backward
# y is left right
# z is up down


def stack_dict_list(l):
keys = l[0].keys()
return [jnp.stack([jnp.array(elm[k]).astype(float) for elm in l]) for k in keys]


# @jit
def get_init(res_x, res_y, x_persp, y_persp, camera_persp, x_offset, y_offset):
camera_pos = jnp.array([-camera_persp, 0, 0]) # from above: set last to -2

focal_plane = jnp.zeros((res_x, res_y, 3))
x_grid, y_grid = jnp.meshgrid(
jnp.linspace(-x_persp, x_persp, res_x), jnp.linspace(-y_persp, y_persp, res_y)
)
focal_plane = focal_plane.at[:, :, 1].set(x_grid.T + x_offset)
focal_plane = focal_plane.at[:, :, 2].set(y_grid.T + y_offset)
ray_dirs = focal_plane - camera_pos
ray_dirs = ray_dirs / jnp.linalg.norm(ray_dirs, axis=-1, keepdims=True)

ray_origin = jnp.empty((res_x, res_y, 3))
ray_origin = ray_origin.at[:, :].set(camera_pos)

return ray_origin, ray_dirs
Loading

0 comments on commit 2da55e9

Please sign in to comment.