Skip to content

Commit

Permalink
Merge in changes from prime/main fixing conflicts.
Browse files Browse the repository at this point in the history
  • Loading branch information
hsorby committed Feb 3, 2025
2 parents ce10d8f + 5429621 commit f1abeaa
Show file tree
Hide file tree
Showing 6 changed files with 751 additions and 9 deletions.
51 changes: 51 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import io
import os

from setuptools import setup, find_packages

SETUP_DIR = os.path.dirname(os.path.abspath(__file__))


# List all of your Python package dependencies in the
# requirements.txt file

def readfile(filename, split=False):
with io.open(filename, encoding="utf-8") as stream:
if split:
return stream.read().split("\n")
return stream.read()


readme = readfile("README.rst", split=True)
# For requirements not hosted on PyPi place listings
# into the 'requirements.txt' file.
requires = [
# minimal requirements listing
"cmlibs.maths >= 0.3",
"cmlibs.utils >= 0.10",
"cmlibs.zinc >= 4.0"
]
readme.extend(['', 'License', '=======', '', '::', ''])
source_license = readfile("LICENSE")

setup(
name="scaffoldfitter",
version="0.10.0",
description="Scaffold/model geometric fitting library using Zinc.",
long_description="\n".join(readme) + source_license,
long_description_content_type="text/x-rst",
classifiers=[
"Development Status :: 3 - Alpha",
"Programming Language :: Python",
"Topic :: Scientific/Engineering :: Medical Science Apps."
],
author="Auckland Bioengineering Institute",
author_email="[email protected]",
url="https://github.com/ABI-Software/scaffoldfitter",
license="Apache Software License",
packages=find_packages("src"),
package_dir={"": "src"},
include_package_data=True,
zip_safe=False,
install_requires=requires,
)
29 changes: 22 additions & 7 deletions src/scaffoldfitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,6 @@ def _defineCommonMeshFields(self):
self._strainPenaltyField.assignReal(fieldcache, zeroValues)
self._curvaturePenaltyField.assignReal(fieldcache, zeroValues)
element = elemIter.next()
self._fieldmodule.endChange()
self._fieldmodule.beginChange()

def getStrainPenaltyField(self):
return self._strainPenaltyField
Expand Down Expand Up @@ -508,7 +506,8 @@ def run(self, endStep=None, modelFileNameStem=None, reorder=False):
endStep = self._fitterSteps[-1]
endIndex = self._fitterSteps.index(endStep)
# reload only if necessary
if endStep.hasRun() and (endIndex < (len(self._fitterSteps) - 1)) and self._fitterSteps[endIndex + 1].hasRun() or reorder:
if (endStep.hasRun() and (endIndex < (len(self._fitterSteps) - 1)) and self._fitterSteps[endIndex + 1].hasRun()
or reorder):
# re-load to get back to current state
self.load()
for index in range(1, endIndex + 1):
Expand All @@ -527,7 +526,7 @@ def getDataCoordinatesField(self):
return self._dataCoordinatesField

def setDataCoordinatesField(self, dataCoordinatesField: Field):
if dataCoordinatesField == self._dataCoordinatesField:
if (self._dataCoordinatesField is not None) and (dataCoordinatesField == self._dataCoordinatesField):
return
finiteElementField = dataCoordinatesField.castFiniteElement()
assert finiteElementField.isValid() and (finiteElementField.getNumberOfComponents() == 3)
Expand Down Expand Up @@ -581,7 +580,7 @@ def setMarkerGroup(self, markerGroup: Field):
self._markerDataNameField = None
self._markerDataLocationGroupField = None
self._markerDataLocationGroup = None
if not markerGroup:
if not (markerGroup and markerGroup.isValid()):
return
fieldGroup = markerGroup.castGroup()
assert fieldGroup.isValid()
Expand Down Expand Up @@ -1011,7 +1010,7 @@ def getModelReferenceCoordinatesField(self):
return self._modelReferenceCoordinatesField

def setModelCoordinatesField(self, modelCoordinatesField: Field):
if modelCoordinatesField == self._modelCoordinatesField:
if (self._modelCoordinatesField is not None) and (modelCoordinatesField == self._modelCoordinatesField):
return
finiteElementField = modelCoordinatesField.castFiniteElement()
mesh = self.getHighestDimensionMesh()
Expand Down Expand Up @@ -1205,6 +1204,9 @@ def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup,
sizeBefore = dataProjectionNodesetGroup.getSize()
dataCoordinates = self._dataCoordinatesField
dataProportion = activeFitterStepConfig.getGroupDataProportion(groupName)[0]
outlierLength = activeFitterStepConfig.getGroupOutlierLength(groupName)[0]
maximumProjectionLength = 0.0
dataProjectionLengths = [] # For relative outliers: list of (data identifier, projection length)
centralProjection = activeFitterStepConfig.getGroupCentralProjection(groupName)[0]
if centralProjection:
# use centre of bounding box as middle of data; previous use of mean was affected by uneven density
Expand Down Expand Up @@ -1255,8 +1257,21 @@ def calculateGroupDataProjections(self, fieldcache, group, dataGroup, meshGroup,
result = meshLocation.assignMeshLocation(fieldcache, element, xi)
assert result == RESULT_OK, \
"Error: Failed to assign data projection mesh location for group " + groupName
dataProjectionNodesetGroup.addNode(node)
result, projectionLength = self._dataErrorField.evaluateReal(fieldcache, 1)
if projectionLength > maximumProjectionLength:
maximumProjectionLength = projectionLength
if outlierLength < 0.0:
# store and filter once we know the maximum
dataProjectionLengths.append((node.getIdentifier(), projectionLength))
if (outlierLength <= 0.0) or (projectionLength <= outlierLength):
dataProjectionNodesetGroup.addNode(node)
node = nodeIter.next()
if outlierLength < 0.0:
relativeOutlierLength = (1.0 + outlierLength) * maximumProjectionLength
for nodeIdentifier, projectionLength in dataProjectionLengths:
if projectionLength > relativeOutlierLength:
node = dataGroup.findNodeByIdentifier(nodeIdentifier)
dataProjectionNodesetGroup.removeNode(node)
pointsProjected = dataProjectionNodesetGroup.getSize() - sizeBefore
if pointsProjected < dataGroup.getSize():
if self.getDiagnosticLevel() > 0:
Expand Down
48 changes: 46 additions & 2 deletions src/scaffoldfitter/fitterstepconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class FitterStepConfig(FitterStep):
_jsonTypeId = "_FitterStepConfig"
_centralProjectionToken = "centralProjection"
_dataProportionToken = "dataProportion"
_outlierLengthToken = "outlierLength"

def __init__(self):
super(FitterStepConfig, self).__init__()
Expand Down Expand Up @@ -62,7 +63,7 @@ def setGroupCentralProjection(self, groupName, centralProjection):
This can help fit groups which start well away from their targets.
:param groupName: Exact model group name, or None for default group.
:param centralProjection: Boolean True/False or None to reset to global
default. Function ensures value is valid.
default (False). Function ensures value is valid.
"""
if centralProjection is not None:
if not isinstance(centralProjection, bool):
Expand Down Expand Up @@ -96,7 +97,7 @@ def setGroupDataProportion(self, groupName, proportion):
global default.
:param groupName: Exact model group name, or None for default group.
:param proportion: Float valued proportion from 0.0 (0%) to 1.0 (100%),
or None to reset to global default. Function ensures value is valid.
or None to reset to global default (1.0). Function ensures value is valid.
"""
if proportion is not None:
if not isinstance(proportion, float):
Expand All @@ -107,6 +108,49 @@ def setGroupDataProportion(self, groupName, proportion):
proportion = 1.0
self.setGroupSetting(groupName, self._dataProportionToken, proportion)

def clearGroupOutlierLength(self, groupName):
"""
Clear local group outlier length so fall back to last config or global default.
:param groupName: Exact model group name, or None for default group.
"""
self.clearGroupSetting(groupName, self._outlierLengthToken)

def getGroupOutlierLength(self, groupName):
"""
Get relative or absolute length of data projections above which data points are treated as
outliers and not included in the fit, plus flags indicating where it has been set.
Values from -1.0 up to < 0.0 are negative proportions of maximum data projection to exclude,
e.g. -0.1 excludes data points within 10% of the maximum data projection.
Value 0.0 disables outlier filtering and includes all data (subject to other filters).
Values > 0.0 gives absolute projection length above which data points are excluded.
If not set or inherited, gets value from default group.
:param groupName: Exact model group name, or None for default group.
:return: OutlierLength, setLocally, inheritable.
Absolute outlier length > 0.0, or 0.0 to include all data (default).
The second return value is True if the value is set locally to a value
or None if reset locally.
The third return value is True if a previous config has set the value.
"""
return self.getGroupSetting(groupName, self._outlierLengthToken, 0.0)

def setGroupOutlierLength(self, groupName, outlierLength):
"""
Set relative or absolute length of data projections above which data points are treated as
outliers and not included in the fit, plus flags indicating where it has been set, or
reset to global default.
:param groupName: Exact model group name, or None for default group.
:param outlierLength: From -1.0 to < 0.0: negative proportion to exclude, 0.0: disable
outlier filter, > 0.0: absolute projection length above which data points are exclude,
None to reset to global default (0.0).
Function ensures value is valid.
"""
if outlierLength is not None:
if not isinstance(outlierLength, float):
outlierLength = self.getGroupOutlierLength(groupName)[0]
elif outlierLength < -1.0:
outlierLength = -1.0
self.setGroupSetting(groupName, self._outlierLengthToken, outlierLength)

def run(self, modelFileNameStem=None):
"""
Calculate data projections with current settings.
Expand Down
Loading

0 comments on commit f1abeaa

Please sign in to comment.