Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/tax-day-experiment' into peak-fit
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamCorrao committed Apr 12, 2024
2 parents eadb120 + cf51bbd commit e0b9f2e
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 21 deletions.
15 changes: 9 additions & 6 deletions containers/mmm5-tax-day-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ services:
command: conda run -n GSASII --no-capture-output uvicorn bluesky_adaptive.server:app --host 0.0.0.0 --root-path /gsas-agent
environment:
- TILED_API_KEY=$TILED_API_KEY
- HTTPSERVER_API_KEY=$HTTPSERVER_API_KEY
- BS_AGENT_STARTUP_SCRIPT_PATH=/src/pdf-agents/pdf_agents/startup_scripts/mmm5-tax-day/gsas.py
volumes:
- type: bind
source: ../pdf_agents
target: /src/pdf-agents/pdf_agents
source: ../
target: /src/pdf-agents/
read_only: true
- type: bind
source: /etc/bluesky/kafka.yml
Expand Down Expand Up @@ -48,11 +49,12 @@ services:
command: conda run -n GSASII --no-capture-output uvicorn bluesky_adaptive.server:app --host 0.0.0.0 --root-path /kmeans-gsas-agent
environment:
- TILED_API_KEY=$TILED_API_KEY
- HTTPSERVER_API_KEY=$HTTPSERVER_API_KEY
- BS_AGENT_STARTUP_SCRIPT_PATH=/src/pdf-agents/pdf_agents/startup_scripts/mmm5-tax-day/kmeans-gsas.py
volumes:
- type: bind
source: ../pdf_agents
target: /src/pdf-agents/pdf_agents
source: ../
target: /src/pdf-agents/
read_only: true
- type: bind
source: /etc/bluesky/kafka.yml
Expand Down Expand Up @@ -105,10 +107,11 @@ services:
gsas-agent:
condition: service_started
gsas-ui:
condition: healthy
condition: service_healthy
kmeans-gsas:
condition: service_started
kmeans-gsas-ui:
condition: healthy
condition: service_healthy
restart: always


89 changes: 82 additions & 7 deletions pdf_agents/agents.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import ast
import copy
import os
import time as ttime
import uuid
from abc import ABC
from logging import getLogger
Expand All @@ -8,7 +11,7 @@
import numpy as np
import redis
import tiled
from bluesky_adaptive.agents.base import Agent, AgentConsumer
from bluesky_adaptive.agents.base import Agent, AgentConsumer, infer_data_keys
from bluesky_adaptive.agents.simple import SequentialAgentBase
from bluesky_kafka import Publisher
from bluesky_queueserver_api.http import REManagerAPI
Expand All @@ -31,6 +34,7 @@ def __init__(
roi: Optional[Tuple] = None,
norm_region: Optional[Tuple] = None,
offline=False,
metadata=None,
**kwargs,
):
if offline:
Expand Down Expand Up @@ -72,6 +76,8 @@ def __init__(
roi_key=self.roi_key,
roi=self.roi,
)
metadata = metadata or {}
md.update(metadata)
super().__init__(*args, metadata=md, **_default_kwargs)

def measurement_plan(self, point: ArrayLike) -> Tuple[str, List, Dict]:
Expand Down Expand Up @@ -99,8 +105,8 @@ def unpack_run(self, run) -> Tuple[Union[float, ArrayLike], Union[float, ArrayLi
if self.background is not None:
y = y - self.background[1]

ordinate = np.array(run.primary.data[self.roi_key]).flatten()
if self.norm_region is not None:
ordinate = np.array(run.primary.data[self.roi_key]).flatten()
idx_min = (
np.where(ordinate < self.norm_region[0])[0][-1]
if len(np.where(ordinate < self.norm_region[0])[0])
Expand Down Expand Up @@ -131,6 +137,8 @@ def unpack_run(self, run) -> Tuple[Union[float, ArrayLike], Union[float, ArrayLi
ordinate = np.array(run.primary.data[self.roi_key]).flatten()
idx_min = np.where(ordinate < self.roi[0])[0][-1] if len(np.where(ordinate < self.roi[0])[0]) else None
idx_max = np.where(ordinate > self.roi[1])[0][-1] if len(np.where(ordinate > self.roi[1])[0]) else None
else:
idx_min, idx_max = None, None

y = y[idx_min:idx_max]
self._ordinate = ordinate[idx_min:idx_max] # Update self oridnate. Should be constant unless roi changes.
Expand Down Expand Up @@ -274,7 +282,7 @@ def get_beamline_objects() -> dict:
config_file_path="/etc/bluesky/kafka.yml"
)
qs = REManagerAPI(http_server_uri=f"https://qserver.nsls2.bnl.gov/{beamline_tla}")
qs.set_authorization_key(api_key="yyyyy")
qs.set_authorization_key(api_key=os.getenv("HTTPSERVER_API_KEY", "zzzzz"))

kafka_consumer = AgentConsumer(
topics=[
Expand Down Expand Up @@ -325,6 +333,35 @@ def get_offline_objects() -> dict:
def trigger_condition(self, uid) -> bool:
return True

def close_and_restart(self, *, clear_tell_cache=False, retell_all=False, reason=""):
"""Utility for closing and restarting an agent with the same name.
This is primarily for methods that change the hyperparameters of an agent on the fly,
but in doing so may change the shape/nature of the agent document stream. This will
keep the documents consistent between hyperparameters as individual BlueskyRuns.
TODO: OVERRIDE FROM ADAPTIVE. MAKE PR TO FIX UPSTREAM.
Parameters
----------
clear_tell_cache : bool, optional
Clears the cache of data the agent has been told about, by default False.
This is useful for a clean slate.
retell_all : bool, optional
Resets the cache and tells the agent about all previous data, by default False.
This can be useful if the agent has not retained knowledge from previous tells.
reason : str, optional
Reason for closing and restarting the agent, to be recorded to logs, by default ""
"""
self.stop(reason=f"Close and Restart: {reason}")
self.kafka_consumer.closed = False
self._compose_descriptor_bundles = dict()
if clear_tell_cache:
self.tell_cache = list()
elif retell_all:
uids = copy.copy(self.tell_cache)
self.tell_cache = list()
self.tell_agent_by_uid(uids)
self.start()


class PDFSequentialAgent(PDFBaseAgent, SequentialAgentBase):
def __init__(
Expand Down Expand Up @@ -362,11 +399,49 @@ def __init__(self, *args, report_producer: Publisher, **kwargs):
self._report_producer = report_producer
super().__init__(*args, **kwargs)

def start(self, *args, **kwargs):
super().start(*args, **kwargs)
self._report_producer("start", self._compose_run_bundle.start_doc)

def stop(self, exit_status="success", reason=""):
logger.debug("Attempting agent stop.")
stop_doc = self._compose_run_bundle.compose_stop(exit_status=exit_status, reason=reason)
self.agent_catalog.v1.insert("stop", stop_doc)
self._report_producer("stop", stop_doc)
self.kafka_producer.flush()
self.kafka_consumer.stop()
logger.info(
f"Stopped agent with exit status {exit_status.upper()}"
f"{(' for reason: ' + reason) if reason else '.'}"
)

def _write_event(self, stream, doc, uid=None):
"""Add event to builder as event page, and publish to catalog
Taken from bluesky adaptive and modified to write to kafka as well as tiled.
TODO: OVERRIDE FROM ADAPTIVE. MAKE PR TO FIX UPSTREAM.
"""
if not doc:
logger.info(f"No doc presented to write_event for stream {stream}")
return
if stream not in self._compose_descriptor_bundles:
data_keys = infer_data_keys(doc)
self._compose_descriptor_bundles[stream] = self._compose_run_bundle.compose_descriptor(
name=stream, data_keys=data_keys
)
self.agent_catalog.v1.insert("descriptor", self._compose_descriptor_bundles[stream].descriptor_doc)
self._report_producer("descriptor", self._compose_descriptor_bundles[stream].descriptor_doc)

t = ttime.time()
event_doc = self._compose_descriptor_bundles[stream].compose_event(
data=doc, timestamps={k: t for k in doc}, uid=uid
)
self.agent_catalog.v1.insert("event", event_doc)
self._report_producer("event", event_doc)

return event_doc["uid"]

def generate_report(self, **kwargs):
doc = self.report(**kwargs)
uid = self._write_event("report", doc)
self._report_producer("report", doc)
logger.info(f"Generated report. Tiled: {uid}\n Kafka: {doc.get('uid', 'No UID')}")
super().generate_report(**kwargs)
self.close_and_restart(clear_tell_cache=False, retell_all=False, reason="Per-Run Subscribers")

@classmethod
Expand Down
14 changes: 8 additions & 6 deletions pdf_agents/gsas.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
cif_paths: List[Union[str, Path]],
refinement_params: List[dict],
inst_param_path: Union[str, Path],
metadata: dict = None,
**kwargs,
):
self._cif_paths = cif_paths
Expand All @@ -57,7 +58,9 @@ def __init__(
self._recent_x = None
self._recent_y = None
self._recent_uid = None
super().__init__(**kwargs)
metadata = metadata or {}
metadata.update(refinement_params=self._refinement_params)
super().__init__(metadata=metadata, **kwargs)
self.report_on_tell = True

@property
Expand Down Expand Up @@ -176,15 +179,14 @@ def report(self) -> Dict[str, ArrayLike]:

return dict(
data_key=self.data_key,
roi_key=self.roi,
roi=self.roi,
norm_region=self.norm_region,
roi_key=self.roi_key if self.roi_key is not None else "",
roi=self.roi if self.roi is not None else "",
norm_region=self.norm_region if self.norm_region is not None else "",
observable_uid=self._recent_uid,
independent_variable=self._recent_x,
raw_independent_variable=self._recent_x,
observable=self._recent_y,
cif_paths=self.cif_paths,
inst_param_path=self.inst_param_path,
refinement_params=self.refinement_params,
gsas_rwps=np.array(gsas_rwps),
gsas_ycalcs=np.stack(gsas_ycalcs),
gsas_ydiffs=np.stack(gsas_ydiffs),
Expand Down
146 changes: 146 additions & 0 deletions pdf_agents/self_avoiding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import numpy as np
import scipy.ndimage as sndi
import matplotlib.pyplot as plt
import matplotlib as mpl
from typing import Tuple


def compute_prob(a, *, sigma=20, gamma=2, bonus=None):
if (a != 0).all():
dist = np.ones_like(a)
else:
dist = sndi.distance_transform_edt(a)
if gamma != 1:
dist **= gamma

dist = 1-np.exp(-dist/ (2*sigma*sigma))

if bonus is not None:
dist += bonus

dist[np.isnan(a)] = 0
dist[dist<=0] = 0
dist /= dist.sum()

return dist


def pick_next_point(a, N=1, *, gamma=1, sigma=20, bonus=None):
dist = compute_prob(a, gamma=gamma, sigma=sigma, bonus=bonus)
cdf = np.cumsum(dist.ravel())
r = np.random.rand(N)
idx = np.searchsorted(cdf, r, side='right')
return np.unravel_index(idx, a.shape)


def show_dan(a, extent=None, *, gamma=1, sigma=20, bonus=None):
fig, (ax1, ax2) = plt.subplots(1, 2, layout="constrained", sharex=True, sharey=True, figsize=(8, 4))
ax1.imshow(a, origin="lower", extent=extent)
d = compute_prob(a, gamma=gamma, sigma=sigma, bonus=bonus)
cmap = mpl.colormaps['viridis']
cmap.set_under('w')
im = ax2.imshow(d, origin="lower", vmin=1e-25, extent=extent, interpolation_stage='rgba', cmap=cmap)

fig.colorbar(im, extend='min')

return fig, (ax1, ax2)


class WaferManager:
def __init__(self, radius: float, resolution: float, center: Tuple[float, float]):
"""
Parameters
----------
radius : float
radius of wafer.
In same units as resolution and center.
resolution : float
scale to quantize the space on.
In same units as radius and center
center : Tuple[float, float]
location of center of wafer in motor coordiates
In same units as radius and resoultion
"""
N = int(np.ceil(2 * radius / resolution))

self._measured = np.ones((N, N))
self._center = np.array(center)
# 0, 0 in the mask is the lower outboard
self._ll = self._center - radius
self._resolution = resolution
self._radius = radius
self._mask = ~(np.hypot(*np.ogrid[-N//2:N//2, -N//2:N//2]) < N//2)

self._measured[self._mask] = np.nan
self.gamma = 2
self.sigma = 20
self.extra = np.zeros_like(self._measured)
self.runs = {}

def xy_to_index(self, xy):
xy = np.atleast_1d(xy)
rel_xy = xy - self._ll
return tuple((rel_xy[::-1] // self._resolution).astype("int"))

def add_measurement(self, xy, hdr=None):

indx = self.xy_to_index(xy)
self._measured[indx] = 0
self.runs[indx] = hdr

def add_header(self, hdr):
xy = self._extract_coords(hdr)
self.add_measurement(xy, hdr=hdr)

@staticmethod
def _extract_coords(hdr):
raise Exception("WRITE ME")

def debug_vis(self):
center_x, center_y = self._center

return show_dan(
self._measured,
extent=[
center_x - self._radius,
center_x + self._radius,
center_y - self._radius,
center_y + self._radius,
],
gamma=self.gamma,
sigma=self.sigma,
bonus=self.extra
)

def pick_next_point(self, N=1):
y, x = pick_next_point(self._measured, N=N, gamma=self.gamma, sigma=self.sigma, bonus=self.extra)
return (np.stack([x, y]).T * self._resolution + self._ll).T

def manual_demo():
wm = WaferManager(40, 0.4, (-130, 50))
wm.extra[50:100, 50:100] = -100
wm.add_measurement((-110, 50))
fig, (ax1, ax2) = wm.debug_vis()
ax2.plot(*wm.pick_next_point(N=100), 'o')

for j in range(15):
(x,), (y,) = wm.pick_next_point(N=1)
wm.add_measurement((x, y))
fig, (ax1, ax2) = wm.debug_vis()

ax2.plot(*wm.pick_next_point(N=100), 'o')

manual_demo()

wm = WaferManager(40, 0.4, (-130, 50))

def demo_plan(wm):

for j in range(25):
(x,), (y,) = wm.pick_next_point(N=1)
uid = (yield from move_and_measure(x, y))
hdr = db[uid]
wm.add_header(hdr)
2 changes: 1 addition & 1 deletion pdf_agents/startup_scripts/mmm5-tax-day/gsas.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

report_producer = RefinementAgent.get_default_producer()


agent = RefinementAgent(
# GSAS Args
cif_paths=["/src/pdf-agents/assets/fcc.cif"],
Expand Down
Loading

0 comments on commit e0b9f2e

Please sign in to comment.