root/imbot/Graph.py

Revision 99, 6.8 kB (checked in by njohri2, 3 years ago)

Changed ngram_probability formula in Graph.py, merged ngram functions in SABBrain, and added priority queue as a heap to search for replies

Line 
1 # third party dependencies
2 import networkx as nx
3 import matplotlib.pyplot as plot
4
# Positions inside a node tuple; nodes look like ("start", "start"), presumably
# (word, tag) pairs -- confirm against the callers that build nodes.
i_word, i_tag = 0, 1
# Positions inside an AssociationGraph node label, which is built as [count, cache].
i_label, i_cache = 0, 1
# Not referenced by the code in this file; kept for external users.
i_distance, i_count = 0, 1
# Direction selector for DistanceGraph.ngram_probability: i_post means the tested
# node follows the given node; any other value (i.e. i_pre) means it precedes it.
i_pre, i_post = -1, 1
9
class DistanceGraph(object):
    '''Directed graph of co-occurrence counts keyed by distance.

    The data stored on an edge node1 -> node2 is a dict mapping a distance to the
    number of times add_edge() recorded that distance for the edge. Key 0 holds
    the edge's total count over all recorded distances.
    '''

    def __init__(self, log, file_path=r"Brain/distances.pickle"):
        '''
        log       -- object whose add(message) method receives log lines
        file_path -- pickle file used by load() and save()
        '''
        self.graph = nx.DiGraph()
        self.log = log
        self.file_path = file_path

        # start node
        self.start = ("start", "start")
        self.graph.add_node(self.start)

        # end node
        self.end = ("end", "end")
        self.graph.add_node(self.end)

    def clear(self):
        ''' Remove all nodes and edges from the distance graph '''
        # Re-running __init__ builds a fresh graph (with new start/end nodes)
        # while keeping the same log and file path.
        self.__init__(self.log, self.file_path)
        # Single-argument print: same output on Python 2 and Python 3.
        print("%s graph cleared" % self.file_path)

    def load(self):
        '''Load the graph from self.file_path.

        Returns True on success. If the file cannot be read (IOError) the graph
        is reset via clear() and False is returned.
        '''
        # NOTE: gpickle files are unpickled; only load files this program wrote.
        try:
            self.graph = nx.read_gpickle(self.file_path)
        except IOError:
            self.clear()
            return False
        return True

    def save(self):
        ''' Save distance graph nodes and edges to a pickle file '''
        nx.write_gpickle(self.graph, self.file_path)

    def add_node(self, node):
        ''' Add a node to the graph '''
        self.graph.add_node(node)
        self.log.add("added distance node: %s" % str(node))

    def has_node(self, node):
        ''' Checks if a node exists in the graph '''
        return self.graph.has_node(node)

    def add_edge(self, node1, node2, distance):
        '''Record one occurrence of edge node1 -> node2 at `distance`.

        Creates the edge if needed, then increments both the per-distance
        counter and the edge total stored under key 0.
        '''
        if self.graph.has_edge(node1, node2):
            data = self.graph.get_edge(node1, node2)
        else:
            data = {}
        data[distance] = data.get(distance, 0) + 1
        # Key 0 is the running total across all distances.
        # NOTE(review): a distance of 0 would collide with this counter;
        # presumably callers never pass 0 -- confirm.
        data[0] = data.get(0, 0) + 1

        self.graph.add_edge(node1, node2, data)
        self.log.add("added distance edge: %s -- %i -- %s" % (node1, distance, node2))

    def edge(self, node1, node2, distance):
        '''Return the count stored for `distance` on edge node1 -> node2.

        Returns 0 when the edge does not exist or that distance was never
        recorded. A distance of 0 returns the edge's total count.
        '''
        if not self.graph.has_edge(node1, node2):
            return 0
        return self.graph.get_edge(node1, node2).get(distance, 0)

    def predecessors(self, node):
        '''Return the nodes with an edge pointing at `node` ([] if `node` is absent).'''
        return self.graph.predecessors(node) if self.graph.has_node(node) else []

    def successors(self, node):
        '''Return the nodes that edges leaving `node` point to ([] if `node` is absent).'''
        return self.graph.successors(node) if self.graph.has_node(node) else []

    def ngram_probability(self, givenNode, testNode, distance, pre_or_post):
        '''Return the probability that testNode precedes or follows givenNode by `distance`.

        pre_or_post -- i_post: testNode follows givenNode (edge givenNode -> testNode);
                       any other value: testNode precedes it (edge testNode -> givenNode).

        The result is the fraction of the edge's recorded occurrences that were
        at `distance`. It is 0.0 when the edge has never been recorded.
        '''
        if pre_or_post == i_post:
            source, target = givenNode, testNode
        else:
            source, target = testNode, givenNode

        link_occurences = float(self.edge(source, target, 0))
        if link_occurences <= 0.0:
            return 0.0
        return self.edge(source, target, distance) / link_occurences

    def draw_graph(self):
        ''' Draws the graph '''
        # positions for all nodes
        pos = nx.spring_layout(self.graph)

        # draw the graph
        nx.draw_networkx_nodes(self.graph, pos, node_size=2000)
        nx.draw_networkx_edges(self.graph, pos, edgelist=[(u, v) for (u, v, d) in self.graph.edges(data=True)], width=1, alpha=1.0, edge_color='r')

        # turn off x and y axes labels
        plot.xticks([])
        plot.yticks([])

        # display graph
        plot.show()
92                        
93                 # turn off x and y axes labels
94                 plot.xticks([])
95                 plot.yticks([])
96                        
97                 # display graph
98                 plot.show()
99
100 class AssociationGraph(object):
101         def __init__(self, log, file_path=r"Brain/associations.pickle"):
102                 self.graph = nx.LabeledGraph()
103                 self.log = log
104                 self.file_path = file_path
105                 self.max_best_edges = 10
106        
107         def clear(self):
108                 ''' Remove all nodes and edges from the association graph '''
109                 self.__init__(self.log)
110                 print self.file_path, "graph cleared"
111        
112         def load(self):
113                 ''' Load association graph nodes and edges from a pickle file '''
114                 try:
115                         self.graph = nx.read_gpickle(self.file_path)
116                         return True
117                
118                 except IOError:
119                         self.clear()
120                         return False
121        
122         def save(self):
123                 ''' Save association graph nodes and edges to a pickle file '''
124                 nx.write_gpickle(self.graph, self.file_path)
125        
126         def add_edge(self, node1, node2, association):
127                 ''' Create or update the association edge between two nodes '''
128                 if not self.graph.has_node(node1):
129                                 self.graph.add_node(node1, [1, []])
130                 if not self.graph.has_node(node2):
131                         self.graph.add_node(node2, [1, []])
132                 association = association + (self.graph.get_edge(node1, node2) if self.graph.has_edge(node1, node2) else 0.0)
133                 self.graph.add_edge(node1, node2, association)
134                 self.log.add("added association edge: %s -- %i -- %s" % (node1, association, node2))
135                
136                 # update our 'best edges' cache for node1
137                 best_edges = self.graph.label[node1][i_cache]
138                 best_edges.append((association, node2))
139                 self.graph.label[node1][i_cache] = sorted(best_edges)[:self.max_best_edges]
140                
141                 # update our 'best edges' cache for node2
142                 best_edges = self.graph.label[node2][i_cache]
143                 best_edges.append((association, node1))
144                 self.graph.label[node2][i_cache] = sorted(best_edges)[:self.max_best_edges]
145        
146         def best_edges(self, node):
147                 ''' Returns a list of (association, node) tuples where 'association' is large and points to 'node' '''
148                 return self.graph.label[node2][i_cache]
149        
150         def increment(self, node):
151                 ''' Increment the number of times we've seen 'node' and store the result in its label '''
152                 if not self.graph.has_node(node):
153                         self.graph.add_node(node, [1, []])
154                         self.log.add("added association node %s = 1" % str(node))
155                 else:
156                         existing_value = self.graph.label[node][i_label]
157                         self.graph.label[node][i_label] = existing_value + 1
158                         self.log.add("updated association node %s = %d" % (str(node), existing_value + 1))
159        
160         def has_edge(self, node1, node2):
161                 ''' Returns True if an edge between 'node1' and 'node2' exists, otherwise False '''
162                 return self.graph.has_edge(node1, node2)
163
164         def get_edge(self, node1, node2):
165                 ''' Returns the value of the edge between 'node1' and 'node2' '''
166                 if not self.graph.has_node(node1) or not self.graph.has_node(node2):
167                         self.log.add("WARNING: get_edge(%s, %s) returned the value of nodes that didn't previously exist!" % (str(node1), str(node2)))
168                 if not self.has_edge(node1, node2):
169                         return 0
170                 return self.graph.get_edge(node1, node2)
171                
172         def label(self, node):
173                 ''' Returns the label of the node '''
174                 return (self.graph.label[node][i_label] if self.graph.has_node(node) else 0)
175                
176         def draw_graph(self):
177                 ''' Draws the graph '''
178                 # positions for all nodes
179                 pos = nx.spring_layout(self.graph)
180                                
181                 # draw the graph
182                 nx.draw_networkx_nodes(self.graph, pos, node_size=2000)
183                 nx.draw_networkx_edges(self.graph, pos, edgelist=[(u,v) for (u,v,d) in self.graph.edges(data=True)], width=1, alpha=1.0, edge_color='r')
184                 nx.draw_networkx_labels(self.graph, pos, font_size=8, font_family='sans-serif')
185                        
186                 # turn off x and y axes labels
187                 plot.xticks([])
188                 plot.yticks([])
189                        
190                 # display graph
191                 plot.show()
192                
193
194
195
Note: See TracBrowser for help on using the browser.