Skip to content

Commit

Permalink
modify duplicate fix and check functions
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaohang007 committed Jul 27, 2023
1 parent 883e6dd commit 720ea89
Showing 1 changed file with 46 additions and 73 deletions.
119 changes: 46 additions & 73 deletions invcryrep/invcryrep/invcryrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def structure2structure_graph(self,structure):
print("ERROR - graph_method not implemented")
return structure_graph

def from_SLICES(self,SLICES):
def from_SLICES(self,SLICES,fix_duplicate_edge=False):
"""
extract edge_indices, to_jimages and atom_types from decoding a SLICES string
"""
Expand Down Expand Up @@ -466,11 +466,41 @@ def from_SLICES(self,SLICES):
to_jimages[i,j]=1
else:
raise Exception("Error: wrong edge label")
if fix_duplicate_edge:
edge_data_ascending=[]
for i in range(len(edge_indices)):
if edge_indices[i][0]<=edge_indices[i][1]:
edge_data_ascending.append(list(edge_indices[i])+list(to_jimages[i]))
else:
edge_data_ascending.append([edge_indices[i][1],edge_indices[i][0]]+list(np.array(to_jimages[i])*-1))
edge_data_ascending=np.array(edge_data_ascending,dtype=int)
edge_data_ascending_unique=np.unique(edge_data_ascending,axis=0)
edge_indices=edge_data_ascending_unique[:,:2]
to_jimages=edge_data_ascending_unique[:,2:]
self.edge_indices=edge_indices
self.to_jimages=to_jimages
self.atom_types=np.array([int(periodic_data.loc[periodic_data["symbol"]==i].values[0][0]) for i in self.atom_symbols])

def check_SLICES(self,SLICES):
def to_SLICES(self):
    """
    encode self.atom_types, self.edge_indices and self.to_jimages into a SLICES string

    Layout of the returned string (every token is followed by one space,
    so a non-empty result always ends with a trailing space):
        <symbol> ... <symbol> <i> <j> <l1> <l2> <l3> <i> <j> <l1> ...
    where each <symbol> is str(ElementBase.from_Z(z)) for z in
    self.atom_types, <i> <j> are the two entries of an edge in
    self.edge_indices, and each label is '-', 'o' or '+' for an offset
    of -1, 0 or 1 in the matching row of self.to_jimages.

    Returns:
        str: the SLICES string ('' when there are no atoms and no edges).
    """
    # offset value -> edge label character
    label_of = {-1: '-', 0: 'o', 1: '+'}

    def get_slices3(atom_symbols, edge_indices, to_jimages):
        # Collect tokens and join once instead of growing a string with
        # repeated "+=" inside nested loops.
        tokens = list(atom_symbols)
        for edge_no, edge in enumerate(edge_indices):
            tokens.append(str(edge[0]))
            tokens.append(str(edge[1]))
            for offset in to_jimages[edge_no]:
                label = label_of.get(offset)
                # Offsets other than -1/0/1 produce no token; the
                # previous implementation ignored them silently as well.
                if label is not None:
                    tokens.append(label)
        return ''.join(token + ' ' for token in tokens)

    atom_symbols = [str(ElementBase.from_Z(i)) for i in self.atom_types]
    return get_slices3(atom_symbols, self.edge_indices, self.to_jimages)


def check_SLICES(self,SLICES,dupli_check=True):
try:
self.from_SLICES(SLICES)
except:
Expand Down Expand Up @@ -502,78 +532,21 @@ def check_SLICES(self,SLICES):
return False
#print(edge_index_covered)
# check duplicates (flip)
edge_data_ascending=[]
for i in range(len(self.edge_indices)):
if self.edge_indices[i][0]<=self.edge_indices[i][1]:
edge_data_ascending.append(list(self.edge_indices[i])+list(self.to_jimages[i]))
else:
edge_data_ascending.append([self.edge_indices[i][1],self.edge_indices[i][0]]+list(np.array(self.to_jimages[i])*-1))
def remove_duplicate_arrays(arrays):
unique_arrays = []
for array in arrays:
if array not in unique_arrays:
unique_arrays.append(array)
return unique_arrays
if len(edge_data_ascending)>len(remove_duplicate_arrays(edge_data_ascending)):
return False
# strict case: (still not covering all cases)
if len(edge_index_covered[1])>=len(edge_index_covered[0]):
b_sub_a = [i for i in edge_index_covered[1] if i not in edge_index_covered[0]]
else:
b_sub_a = [i for i in edge_index_covered[0] if i not in edge_index_covered[1]]
a_add_b = edge_index_covered[0]+edge_index_covered[1]
if len(a_add_b)>=len(edge_index_covered[2]):
c_sub_ab = [i for i in a_add_b if i not in edge_index_covered[2]]
else:
c_sub_ab = [i for i in edge_index_covered[2] if i not in a_add_b]
#print(b_sub_a,c_sub_ab)
if len(b_sub_a)==0 or len(c_sub_ab)==0:
return False
try:
x_dat, net_voltage = self.convert_graph()
#print(x_dat,net_voltage)
net = Net(x_dat,dim=3)
net.voltage = net_voltage
#print(net.graph.edges)
# check the graph first (super fast)
net.simple_cycle_basis()
net.get_lattice_basis()
net.get_cocycle_basis()
except:
return False
return True

def check_SLICES_without_dupli(self,SLICES):
try:
self.from_SLICES(SLICES)
except:
return False
# make sure the rank of first homology group of graph >= 3, in order to get 3D embedding
G = nx.MultiGraph()
G.add_nodes_from([i for i in range(len(self.atom_types))])
G.add_edges_from(self.edge_indices) # convert to MultiGraph (from MultiDiGraph) !MST can only deal with MultiGraph
mst = tree.minimum_spanning_edges(G, algorithm="kruskal", data=False)
b=G.size()-len(list(mst)) # rank of first homology group of graph X(V,E); rank H1(X,Z) = |E| − |E1|
#print(b)
if b < 3:
return False
# check if all nodes has been covered by edges
nodes_covered=[]
for i in self.edge_indices:
nodes_covered.append(i[0])
nodes_covered.append(i[1])
if len(set(nodes_covered))!=len(self.atom_types):
return False
# check if edge labels covers 3 dimension in at least 3 edges, in order to get 3D embedding
edge_index_covered=[[],[],[]]
for i in range(len(self.to_jimages)):
for j in range(3):
if self.to_jimages[i][j]!=0:
edge_index_covered[j].append(i)
for i in edge_index_covered:
if len(i)==0:
if dupli_check:
edge_data_ascending=[]
for i in range(len(self.edge_indices)):
if self.edge_indices[i][0]<=self.edge_indices[i][1]:
edge_data_ascending.append(list(self.edge_indices[i])+list(self.to_jimages[i]))
else:
edge_data_ascending.append([self.edge_indices[i][1],self.edge_indices[i][0]]+list(np.array(self.to_jimages[i])*-1))
def remove_duplicate_arrays(arrays):
unique_arrays = []
for array in arrays:
if array not in unique_arrays:
unique_arrays.append(array)
return unique_arrays
if len(edge_data_ascending)>len(remove_duplicate_arrays(edge_data_ascending)):
return False
#print(edge_index_covered)
# strict case: (still not covering all cases)
if len(edge_index_covered[1])>=len(edge_index_covered[0]):
b_sub_a = [i for i in edge_index_covered[1] if i not in edge_index_covered[0]]
Expand Down

0 comments on commit 720ea89

Please sign in to comment.