Skip to content

Commit

Permalink
Merge pull request #13 from bonsai-rx/lds-forecast-mashup-visualizers
Browse files Browse the repository at this point in the history
Added forecast visualizers
  • Loading branch information
glopesdev authored Jun 4, 2024
2 parents 325a30a + 417bf34 commit aeb9c45
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/Bonsai.ML.Visualizers/Bonsai.ML.Visualizers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
<ItemGroup>
<PackageReference Include="Bonsai.Core" Version="2.8.1" />
<PackageReference Include="Bonsai.Design" Version="2.8.0" />
<PackageReference Include="Bonsai.Vision.Design" Version="2.8.1" />
<PackageReference Include="MathNet.Numerics" Version="5.0.0" />
<PackageReference Include="OxyPlot.Core" Version="2.1.2" />
<PackageReference Include="OxyPlot.WindowsForms" Version="2.1.2" />
</ItemGroup>
Expand Down
91 changes: 91 additions & 0 deletions src/Bonsai.ML.Visualizers/ForecastImageOverlay.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using Bonsai.Design;
using Bonsai.Vision.Design;
using Bonsai;
using Bonsai.ML.Visualizers;
using Bonsai.ML.LinearDynamicalSystems.Kinematics;
using System;
using System.Collections.Generic;
using OpenCV.Net;
using MathNet.Numerics.LinearAlgebra;
using OxyPlot;

[assembly: TypeVisualizer(typeof(ForecastImageOverlay), Target = typeof(MashupSource<ImageMashupVisualizer, ForecastVisualizer>))]

namespace Bonsai.ML.Visualizers
{
/// <summary>
/// Provides a mashup visualizer to display the forecast of a Kalman Filter kinematics model overtime of an ImageMashupVisualizer.
/// </summary>
public class ForecastImageOverlay : DialogTypeVisualizer
{
private ImageMashupVisualizer visualizer;
private IplImage overlay;

/// <inheritdoc/>
public override void Show(object value)
{

var image = visualizer.VisualizerImage;
Size size = new Size(image.Width, image.Height);
IplDepth depth = image.Depth;
int channels = image.Channels;

overlay = new IplImage(size, depth, channels);
var alpha = 0.1;

Forecast forecast = (Forecast)value;
List<ForecastResult> forecastResults = forecast.ForecastResults;

for (int i = 0; i < forecastResults.Count; i++)
{
var forecastResult = forecastResults[i];
var kinematicState = forecastResult.KinematicState;

double xMean = kinematicState.Position.X.Mean;
double yMean = kinematicState.Position.Y.Mean;

Point center = new Point((int)Math.Round(xMean), (int)Math.Round(yMean));

double xVar = kinematicState.Position.X.Variance;
double yVar = kinematicState.Position.Y.Variance;
double xyCov = kinematicState.Position.Covariance;

var covariance = Matrix<double>.Build.DenseOfArray(new double[,] {
{ xVar, xyCov },
{ xyCov, yVar }
});

var evd = covariance.Evd();
var evals = evd.EigenValues.Real();
var evecs = evd.EigenVectors;

double angle = Math.Atan2(evecs[1, 0], evecs[0, 0]) * 180 / Math.PI;

Size axes = new Size
{
Width = (int)(2 * Math.Sqrt(evals[0])),
Height = (int)(2 * Math.Sqrt(evals[1]))
};

OxyColor color = OxyColors.Yellow;

CV.Ellipse(overlay, center, axes, angle, 0, 360, new Scalar(color.B, color.G, color.R, color.A), -1);
}

CV.AddWeighted(image, 1 - alpha, overlay, alpha, 1, image);
overlay.SetZero();
}

/// <inheritdoc/>
public override void Load(IServiceProvider provider)
{
visualizer = (ImageMashupVisualizer)provider.GetService(typeof(MashupVisualizer));
}

/// <inheritdoc/>
public override void Unload()
{
overlay.Dispose();
}
}
}
116 changes: 116 additions & 0 deletions src/Bonsai.ML.Visualizers/ForecastPlotOverlay.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
using Bonsai.Design;
using Bonsai;
using Bonsai.ML.Visualizers;
using Bonsai.ML.LinearDynamicalSystems;
using Bonsai.ML.LinearDynamicalSystems.Kinematics;
using System;
using System.Collections.Generic;
using OxyPlot.Series;
using OxyPlot;

[assembly: TypeVisualizer(typeof(ForecastPlotOverlay), Target = typeof(MashupSource<KinematicStateVisualizer, ForecastVisualizer>))]

namespace Bonsai.ML.Visualizers
{
/// <summary>
/// Provides a mashup visualizer to display the forecast of a Kalman Filter kinematics model overtime of a KinematicStateVisualizer.
/// </summary>
public class ForecastPlotOverlay : DialogTypeVisualizer
{
private List<LineSeries> lineSeriesList = new();

private List<AreaSeries> areaSeriesList = new();

private KinematicStateVisualizer visualizer;

/// <inheritdoc/>
public override void Show(object value)
{
var time = DateTime.Now;
Forecast forecast = (Forecast)value;
var componentVisualizers = visualizer.ComponentVisualizers;

for (int i = 0; i < componentVisualizers.Count; i++)
{
var plot = componentVisualizers[i].Plot;
var lineSeries = lineSeriesList[i];
var areaSeries = areaSeriesList[i];

plot.ResetLineSeries(lineSeries);
plot.ResetAreaSeries(areaSeries);

DateTime forecastTime = time;

for (int j = 0; j < forecast.ForecastResults.Count; j++)
{
var forecastResult = forecast.ForecastResults[j];
var kinematicState = forecastResult.KinematicState;
forecastTime = time + forecastResult.Timestep;

StateComponent[] stateComponents = new StateComponent[] {kinematicState.Position.X, kinematicState.Position.Y, kinematicState.Velocity.X, kinematicState.Velocity.Y, kinematicState.Acceleration.X, kinematicState.Acceleration.Y};

AddStateComponentDataToSeries(plot, stateComponents[i], lineSeries, areaSeries, forecastTime);

}

plot.SetAxes(minTime: forecastTime.AddSeconds(-plot.Capacity), maxTime: forecastTime);
}
}

private void AddStateComponentDataToSeries(TimeSeriesOxyPlotBase plot, StateComponent stateComponent, LineSeries lineSeries, AreaSeries areaSeries, DateTime time)
{
double mean = stateComponent.Mean;
double variance = stateComponent.Variance;

plot.AddToLineSeries(
lineSeries: lineSeries,
time: time,
value: mean
);

plot.AddToAreaSeries(
areaSeries: areaSeries,
time: time,
value1: mean + variance,
value2: mean - variance
);
}

/// <inheritdoc/>
public override void Load(IServiceProvider provider)
{
if (lineSeriesList.Count > 0)
{
lineSeriesList.Clear();
lineSeriesList = new();
}

if (areaSeriesList.Count > 0)
{
areaSeriesList.Clear();
areaSeriesList = new();
}

var service = provider.GetService(typeof(MashupVisualizer));
visualizer = (KinematicStateVisualizer)service;
var componentVisualizers = visualizer.ComponentVisualizers;

for (int i = 0; i < componentVisualizers.Count; i++)
{
var lineSeries = componentVisualizers[i].Plot.AddNewLineSeries($"Forecast {visualizer.Labels[i]} Mean", color: OxyColors.Yellow);
var areaSeries = componentVisualizers[i].Plot.AddNewAreaSeries($"Forecast {visualizer.Labels[i]} Variance", color: OxyColors.Yellow, opacity: 50);

componentVisualizers[i].Plot.ResetLineSeries(lineSeries);
componentVisualizers[i].Plot.ResetAreaSeries(areaSeries);

lineSeriesList.Add(lineSeries);
areaSeriesList.Add(areaSeries);
}
}

/// <inheritdoc/>
public override void Unload()
{
}
}
}
133 changes: 133 additions & 0 deletions src/Bonsai.ML.Visualizers/ForecastVisualizer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
using System;
using System.Windows.Forms;
using System.Collections.Generic;
using Bonsai;
using Bonsai.Design;
using Bonsai.ML.Visualizers;
using Bonsai.ML.LinearDynamicalSystems.Kinematics;
using OxyPlot;
using System.Reactive;
using System.Linq;
using System.Reactive.Linq;

[assembly: TypeVisualizer(typeof(ForecastVisualizer), Target = typeof(Forecast))]

namespace Bonsai.ML.Visualizers
{
/// <summary>
/// Provides a type visualizer to display the forecast of a Kalman Filter kinematics model.
/// </summary>
public class ForecastVisualizer : BufferedVisualizer
{

private int rowCount = 3;
private int columnCount = 2;
private string[] labels = new string[] {
"Forecast Position X",
"Forecast Position Y",
"Forecast Velocity X",
"Forecast Velocity Y",
"Forecast Acceleration X",
"Forecast Acceleration Y"
};

private List<StateComponentVisualizer> componentVisualizers = new();
private TableLayoutPanel container;

/// <inheritdoc/>
public override void Load(IServiceProvider provider)
{
container = new TableLayoutPanel
{
ColumnCount = columnCount,
RowCount = rowCount,
Dock = DockStyle.Fill
};

for (int i = 0; i < container.RowCount; i++)
{
container.RowStyles.Add(new RowStyle(SizeType.Percent, 100f / rowCount));
}

for (int i = 0; i < container.ColumnCount; i++)
{
container.ColumnStyles.Add(new ColumnStyle(SizeType.Percent, 100f / columnCount));
}

for (int i = 0 ; i < rowCount; i++)
{
for (int j = 0; j < columnCount; j++)
{
var StateComponentVisualizer = new StateComponentVisualizer() {
Label = labels[i * columnCount + j],
LineSeriesColor = OxyColors.Yellow,
AreaSeriesColor = OxyColors.Yellow
};
StateComponentVisualizer.Load(provider);
container.Controls.Add(StateComponentVisualizer.Plot, j, i);
componentVisualizers.Add(StateComponentVisualizer);
}
}

var visualizerService = (IDialogTypeVisualizerService)provider.GetService(typeof(IDialogTypeVisualizerService));

if (visualizerService != null)
{
visualizerService.AddControl(container);
}
}

/// <inheritdoc/>
public override void Show(object value)
{
}

/// <inheritdoc/>
protected override void ShowBuffer(IList<Timestamped<object>> values)
{
if (values.Count == 0) return;
var latestForecast = values.Last();
var timestamp = latestForecast.Timestamp;
var forecast = (Forecast)latestForecast.Value;
var futureTime = timestamp;

List<Timestamped<object>> positionX = new();
List<Timestamped<object>> positionY = new();
List<Timestamped<object>> velocityX = new();
List<Timestamped<object>> velocityY = new();
List<Timestamped<object>> accelerationX = new();
List<Timestamped<object>> accelerationY = new();

foreach (var forecastResult in forecast.ForecastResults)
{
futureTime = timestamp + forecastResult.Timestep;
positionX.Add(new Timestamped<object>(forecastResult.KinematicState.Position.X, futureTime));
positionY.Add(new Timestamped<object>(forecastResult.KinematicState.Position.Y, futureTime));
velocityX.Add(new Timestamped<object>(forecastResult.KinematicState.Velocity.X, futureTime));
velocityY.Add(new Timestamped<object>(forecastResult.KinematicState.Velocity.Y, futureTime));
accelerationX.Add(new Timestamped<object>(forecastResult.KinematicState.Acceleration.X, futureTime));
accelerationY.Add(new Timestamped<object>(forecastResult.KinematicState.Acceleration.Y, futureTime));
}

var dataList = new List<List<Timestamped<object>>>() { positionX, positionY, velocityX, velocityY, accelerationX, accelerationY };

var zippedData = dataList.Zip(componentVisualizers, (data, visualizer) => new { Data = data, Visualizer = visualizer });

foreach (var item in zippedData)
{
item.Visualizer.Plot.ResetLineSeries(item.Visualizer.LineSeries);
item.Visualizer.Plot.ResetAreaSeries(item.Visualizer.AreaSeries);
item.Visualizer.ShowDataBuffer(item.Data);
item.Visualizer.Plot.SetAxes(minTime: timestamp.DateTime, maxTime: futureTime.DateTime);
}
}

/// <inheritdoc/>
public override void Unload()
{
foreach (var componentVisualizer in componentVisualizers) componentVisualizer.Unload();
if (componentVisualizers.Count > 0) componentVisualizers.Clear();
if (!container.IsDisposed) container.Dispose();
}
}
}

0 comments on commit aeb9c45

Please sign in to comment.