diff --git a/atomai/nets/fcnn.py b/atomai/nets/fcnn.py
index 107416f9..5f52bdb3 100644
--- a/atomai/nets/fcnn.py
+++ b/atomai/nets/fcnn.py
@@ -20,6 +20,8 @@ class Unet(nn.Module):
     Builds a fully convolutional Unet-like neural network model

     Args:
+        n_channels:
+            Number of channels in the input image
         nb_classes:
             Number of classes in the ground truth
         nb_filters:
@@ -48,6 +50,7 @@ class Unet(nn.Module):
             (to maintain symmetry between encoder and decoder)
     """
     def __init__(self,
+                 n_channels: int = 1,
                  nb_classes: int = 1,
                  nb_filters: int = 16,
                  dropout: bool = False,
@@ -64,7 +67,7 @@ def __init__(self,
         padding_values = dilation_values.copy()
         dropout_vals = [.1, .2, .1] if dropout else [0, 0, 0]
         self.c1 = ConvBlock(
-            2, nbl[0], 1, nb_filters,
+            2, nbl[0], n_channels, nb_filters,
             batch_norm=batch_norm
         )
         self.c2 = ConvBlock(
@@ -148,6 +151,8 @@ class dilnet(nn.Module):
     by utilizing a combination of regular and dilated convolutions

     Args:
+        n_channels:
+            Number of channels in the input image
         nb_classes:
             Number of classes in the ground truth
         nb_filters:
@@ -167,6 +172,7 @@ class dilnet(nn.Module):
     """

     def __init__(self,
+                 n_channels: int = 1,
                  nb_classes: int = 1,
                  nb_filters: int = 25,
                  dropout: bool = False,
@@ -184,7 +190,7 @@ def __init__(self,
         padding_values_2 = dilation_values_2.copy()
         dropout_vals = [.3, .3] if dropout else [0, 0]
         self.c1 = ConvBlock(
-            2, nbl[0], 1, nb_filters,
+            2, nbl[0], n_channels, nb_filters,
             batch_norm=batch_norm
         )
         self.at1 = DilatedBlock(
@@ -231,6 +237,8 @@ class ResHedNet(nn.Module):
     Holistically nested edge detector with residual connections in each block

     Args:
+        n_channels:
+            Number of channels in the input layer
         nb_classes:
             Number of classes in the ground truth
         nb_filters:
@@ -247,6 +255,7 @@ class ResHedNet(nn.Module):
     """

     def __init__(self,
+                 n_channels: int = 1,
                  nb_classes: int = 1,
                  nb_filters: int = 64,
                  upsampling_mode: str = "bilinear",
@@ -257,7 +266,7 @@ def __init__(self,
         super(ResHedNet, self).__init__()
         nbl = kwargs.get("layers", [3, 4, 5])
         self.upsample = upsampling_mode
-        self.net1 = ResModule(2, nbl[0], 1, nb_filters, True)
+        self.net1 = ResModule(2, nbl[0], n_channels, nb_filters, True)
         self.net2 = nn.Sequential(
             nn.MaxPool2d(2, 2),
             ResModule(2, nbl[1], nb_filters, 2*nb_filters, True)
         )
@@ -302,6 +311,8 @@ class SegResNet(nn.Module):
     with residual blocks for semantic segmentation

     Args:
+        n_channels:
+            Number of channels in the input image
         nb_classes:
             Number of classes in the ground truth
         nb_filters:
@@ -321,6 +332,7 @@ class SegResNet(nn.Module):
     '''

     def __init__(self,
+                 n_channels: int = 1,
                  nb_classes: int = 1,
                  nb_filters: int = 32,
                  batch_norm: bool = True,
@@ -333,7 +345,7 @@ def __init__(self,
         super(SegResNet, self).__init__()
         nbl = kwargs.get("layers", [2, 2, 2])
         self.c1 = ConvBlock(
-            2, 1, 1, nb_filters, batch_norm=batch_norm
+            2, 1, n_channels, nb_filters, batch_norm=batch_norm
         )
         self.c2 = ResModule(
             2, nbl[0], nb_filters, nb_filters*2, batch_norm=batch_norm
@@ -386,12 +398,14 @@ def init_fcnn_model(model: Union[Type[nn.Module], str],
         meta_state_dict = {
             'model_type': 'Seg', model: 'custom', 'nb_classes': nb_classes}
         return model, meta_state_dict
+    n_channels = kwargs.get('n_channels', 1)
     batch_norm = kwargs.get('batch_norm', True)
     dropout = kwargs.get('dropout', False)
     upsampling = kwargs.get('upsampling', "bilinear")
     meta_state_dict = {
         'model_type': 'seg',
         'model': model,
+        'n_channels': n_channels,
         'nb_classes': nb_classes,
         'batch_norm': batch_norm,
         'dropout': dropout,
@@ -402,7 +416,7 @@ def init_fcnn_model(model: Union[Type[nn.Module], str],
         nb_filters = kwargs.get('nb_filters', 16)
         layers = kwargs.get("layers", [1, 2, 2, 3])
         net = Unet(
-            nb_classes, nb_filters, dropout,
+            n_channels, nb_classes, nb_filters, dropout,
             batch_norm, upsampling, with_dilation,
             layers=layers
         )
@@ -411,7 +425,7 @@ def init_fcnn_model(model: Union[Type[nn.Module], str],
         nb_filters = kwargs.get('nb_filters', 25)
         layers = kwargs.get("layers", [1, 3, 3, 1])
         net = dilnet(
-            nb_classes, nb_filters,
+            n_channels, nb_classes, nb_filters,
             dropout, batch_norm, upsampling,
             layers=layers
         )
@@ -419,14 +433,14 @@ def init_fcnn_model(model: Union[Type[nn.Module], str],
         nb_filters = kwargs.get('nb_filters', 32)
         layers = kwargs.get("layers", [2, 2, 2])
         net = SegResNet(
-            nb_classes, nb_filters,
+            n_channels, nb_classes, nb_filters,
             batch_norm, upsampling, layers=layers
         )
     elif isinstance(model, str) and model == 'ResHedNet':
         nb_filters = kwargs.get('nb_filters', 64)
         layers = kwargs.get("layers", [3, 4, 5])
         net = ResHedNet(
-            nb_classes, nb_filters,
+            n_channels, nb_classes, nb_filters,
             upsampling, layers=layers
         )
     else:
diff --git a/atomai/utils/graphx.py b/atomai/utils/graphx.py
index d7930f15..1fea1d5a 100644
--- a/atomai/utils/graphx.py
+++ b/atomai/utils/graphx.py
@@ -61,7 +61,8 @@ class Graph:
     """
     def __init__(self, coordinates: np.ndarray,
-                 map_dict: Dict) -> None:
+                 map_dict: Dict,
+                 px2ang: float = 1) -> None:
         """
         Initializes a graph object
         """
@@ -76,6 +77,8 @@ def __init__(self, coordinates: np.ndarray,
             v = Node(i, coords[:-1].tolist(), map_dict[coords[-1]])
             self.vertices.append(v)
         self.coordinates = coordinates
+        self.coordinates_ang = deepcopy(coordinates)
+        self.coordinates_ang[:, :-1] = self.coordinates[:, :-1] * px2ang
         self.map_dict = map_dict
         self.size = len(coordinates)
         self.rings = []
@@ -87,6 +90,10 @@ def find_neighbors(self, **kwargs: float):
         Identifies neighbors of each graph node

         Args:
+            **max_neighbors (int):
+                Maximum number of neighbors each node can have, usually
+                used to form the graph with only nearest neighbors.
+                Default is -1, which means all neighbors are found.
             **expand (float):
                 coefficient determining the maximum allowable expansion of
                 atomic bonds when constructing a graph. For example, the two
@@ -97,34 +104,59 @@ def find_neighbors(self, **kwargs: float):
             del v.neighbors[:]
         Rij = get_interatomic_r
         e = kwargs.get("expand", 1.2)
-        tree = spatial.cKDTree(self.coordinates[:, :3])
-        uval = np.unique(self.coordinates[:, -1])
+        max_neighbors = kwargs.get("max_neighbors", -1)
+        tree = spatial.cKDTree(self.coordinates_ang[:, :3])
+        uval = np.unique(self.coordinates_ang[:, -1])
         if len(uval) == 1:
             rmax = Rij([self.map_dict[uval[0]], self.map_dict[uval[0]]], e)
-            neighbors = tree.query_ball_point(self.coordinates[:, :3], r=rmax)
+            if max_neighbors == -1:
+                neighbors = tree.query_ball_point(self.coordinates_ang[:, :3], r=rmax)
+            else:
+                _, neighbors = tree.query(self.coordinates_ang[:, :3], k=max_neighbors+1, distance_upper_bound=rmax)
             for v, nn in zip(self.vertices, neighbors):
                 for n in nn:
-                    if self.vertices[n] != v:
-                        v.neighbors.append(self.vertices[n])
-                        v.neighborscopy.append(self.vertices[n])
+                    if not n >= len(self.vertices):
+                        if self.vertices[n] != v:
+                            v.neighbors.append(self.vertices[n])
+                            v.neighborscopy.append(self.vertices[n])
+
         else:
             uval = [self.map_dict[u] for u in uval]
             apairs = [(p[0], p[1]) for p in itertools.product(uval, repeat=2)]
             rij = [Rij([a[0], a[1]], e) for a in apairs]
             rmax = np.max(rij)
             rij = dict(zip(apairs, rij))
-            for v, coords in zip(self.vertices, self.coordinates):
+            for v, coords in zip(self.vertices, self.coordinates_ang):
                 atom1 = self.map_dict[coords[-1]]
-                nn = tree.query_ball_point(coords[:3], r=rmax)
-                for n, coords2 in zip(nn, self.coordinates[nn]):
-                    if self.vertices[n] != v:
-                        atom2 = self.map_dict[coords2[-1]]
-                        eucldist = np.linalg.norm(
-                            coords[:3] - coords2[:3])
-                        if eucldist <= rij[(atom1, atom2)]:
-                            v.neighbors.append(self.vertices[n])
-                            v.neighborscopy.append(self.vertices[n])
-
+                if max_neighbors == -1:
+                    nn = tree.query_ball_point(coords[:3], r=rmax)
+                else:
+                    _, nn = tree.query(coords[:3], k=max_neighbors+1, distance_upper_bound=rmax)
+
+                for n in nn:
+                    if not n >= len(self.vertices):
+                        coords2 = self.coordinates_ang[n]
+                        if self.vertices[n] != v:
+                            atom2 = self.map_dict[coords2[-1]]
+                            eucldist = np.linalg.norm(
+                                coords[:3] - coords2[:3])
+                            if eucldist <= rij[(atom1, atom2)]:
+                                v.neighbors.append(self.vertices[n])
+                                v.neighborscopy.append(self.vertices[n])
+
+        # Making the graph symmetric when max_neighbors is used
+        for v in self.vertices:
+            id = v.id
+            rem_ids = []
+            for nn in v.neighbors:
+                nn_neighbors_list = [nn.neighbors[l].id for l in range(len(nn.neighbors))]
+                if id not in nn_neighbors_list:
+                    rem_ids.append(nn)
+
+            for rem_id in rem_ids:
+                v.neighbors.remove(rem_id)
+
     def find_rings(self, v: Type[Node],
                    rings: List[List[Type[Node]]] = [],
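Usage note (illustrative, not part of the patch): a minimal sketch of how the new n_channels keyword is expected to flow through the changed code in atomai/nets/fcnn.py. The 3-channel input, batch size, and image size are made-up example values; only the Unet and init_fcnn_model signatures shown in this diff are assumed.

    import torch
    from atomai.nets.fcnn import Unet, init_fcnn_model

    # A U-Net that accepts 3-channel input instead of the previous
    # single-channel default; the first ConvBlock now receives n_channels.
    net = Unet(n_channels=3, nb_classes=2, nb_filters=16)
    out = net(torch.randn(4, 3, 128, 128))   # (batch, channels, height, width)
    print(out.shape)                          # expected: torch.Size([4, 2, 128, 128])

    # The model factory reads the same keyword from **kwargs and records it
    # in the returned meta-dictionary.
    net, meta = init_fcnn_model("Unet", nb_classes=2, n_channels=3)
    print(meta["n_channels"])                 # 3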
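Similarly, a sketch of the new Graph arguments in atomai/utils/graphx.py: px2ang rescales pixel coordinates to angstroms before the KD-tree search, and max_neighbors caps how many neighbors each node keeps. The coordinates, the 0.1 angstrom-per-pixel calibration, and the single carbon species below are made-up example values, not data from the patch.

    import numpy as np
    from atomai.utils.graphx import Graph

    # Made-up atomic coordinates in pixels: columns are (x, y, z, class_index)
    coords = np.array([
        [10., 10., 0., 0.],
        [25., 10., 0., 0.],
        [10., 25., 0., 0.],
        [60., 60., 0., 0.],
    ])
    map_dict = {0: "C"}   # class index -> chemical element

    # px2ang converts pixel coordinates to angstroms, so the bond-length
    # cutoffs from get_interatomic_r are applied on a physical length scale.
    g = Graph(coords, map_dict, px2ang=0.1)

    # Keep at most 3 nearest neighbors per node within the cutoff radius;
    # max_neighbors=-1 (default) reproduces the old radius-only search.
    g.find_neighbors(expand=1.2, max_neighbors=3)
    for v in g.vertices:
        print(v.id, [n.id for n in v.neighbors])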