updated ChangeLog
[swftools.git] / lib / graphcut.c
1 /*
2     graphcut- a graphcut implementation based on the Boykov Kolmogorov algorithm 
3
4     Part of the swftools package.
5
6     Copyright (c) 2007,2008,2009 Matthias Kramm <kramm@quiss.org>
7
8     This program is free software: you can redistribute it and/or modify
9     it under the terms of the GNU General Public License as published by
10     the Free Software Foundation, either version 3 of the License, or
11     (at your option) any later version.
12
13     This program is distributed in the hope that it will be useful,
14     but WITHOUT ANY WARRANTY; without even the implied warranty of
15     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16     GNU General Public License for more details.
17
18     You should have received a copy of the GNU General Public License
19     along with this program.  If not, see <http://www.gnu.org/licenses/>.
20 */
21
22 #include <stdlib.h>
23 #include <stdio.h>
24 #include <math.h>
25 #include <memory.h>
26 #include "graphcut.h"
27 #include "../mem.h"
28
29 //#define DEBUG
30
31 //#define CHECKS
32
33 #ifdef DEBUG
34 #define DBG
35 #include <assert.h>
36 #else
37 #define DBG if(0)
38 #define assert(x) (x)
39 #endif
40
41 #define ACTIVE 0x10
42 #define IN_TREE 0x20
43
44 #define TWOTREES
45
46 typedef struct _posqueue_entry {
47     node_t*pos;
48     struct _posqueue_entry*next;
49 } posqueue_entry_t;
50
51 typedef struct _posqueue {
52     posqueue_entry_t*list;
53 } posqueue_t;
54
55 typedef struct _graphcut_workspace {
56     unsigned char*flags1;
57     unsigned char*flags2;
58     halfedge_t**back;
59     graph_t*graph;
60     node_t*pos1;
61     node_t*pos2;
62     posqueue_t*queue1;
63     posqueue_t*queue2;
64     posqueue_t*tmpqueue;
65 } graphcut_workspace_t;
66
67 static posqueue_t*posqueue_new() 
68 {
69     posqueue_t*m = (posqueue_t*)malloc(sizeof(posqueue_t));
70     memset(m, 0, sizeof(posqueue_t));
71     return m;
72 }
73 static void posqueue_delete(posqueue_t*q)
74 {
75     posqueue_entry_t*l = q->list;
76     while(l) {
77         posqueue_entry_t*next = l->next;
78         free(l);
79         l = next;
80     }
81     free(q);
82 }
83 static inline void posqueue_addpos(posqueue_t*queue, node_t*pos)
84 {
85     posqueue_entry_t*old = queue->list;
86     queue->list = malloc(sizeof(posqueue_entry_t));
87     queue->list->pos = pos;
88     queue->list->next = old;
89 }
90 static inline node_t* posqueue_extract(posqueue_t*queue)
91 {
92     posqueue_entry_t*item = queue->list;
93     node_t*pos;
94     if(!item)
95         return 0;
96     pos = item->pos;
97     queue->list = queue->list->next;
98     free(item);
99     return pos;
100 }
101 static inline int posqueue_notempty(posqueue_t*queue)
102 {
103     return (int)queue->list;
104 }
105
106 #define NR(p) ((p)->nr)
107
108 static void posqueue_print(graphcut_workspace_t*w, posqueue_t*queue)
109 {
110     posqueue_entry_t*e = queue->list;
111     while(e) {
112         halfedge_t*back = w->back[NR(e->pos)];
113         printf("%d(%d) ", NR(e->pos), back?NR(back->fwd->node):-1);
114         e = e->next;
115     }
116     printf("\n");
117 }
118 static void posqueue_purge(posqueue_t*queue)
119 {
120     posqueue_entry_t*e = queue->list;
121     while(e) {
122         posqueue_entry_t*next = e->next;
123         e->next = 0;free(e);
124         e = next;
125     }
126     queue->list = 0;
127 }
128
129 graph_t* graph_new(int num_nodes)
130 {
131     graph_t*graph = rfx_calloc(sizeof(graph_t));
132     graph->num_nodes = num_nodes;
133     graph->nodes = rfx_calloc(sizeof(node_t)*num_nodes);
134     int t;
135     for(t=0;t<num_nodes;t++) {
136         graph->nodes[t].nr = t;
137     }
138     return graph;
139 }
140
141 void graph_delete(graph_t*graph)
142 {
143     int t;
144     for(t=0;t<graph->num_nodes;t++) {
145         halfedge_t*e = graph->nodes[t].edges;
146         while(e) {
147             halfedge_t*next = e->next;
148             free(e);
149             e = next;
150         }
151     }
152     free(graph->nodes);graph->nodes=0;
153     free(graph);
154 }
155
156 static graphcut_workspace_t*graphcut_workspace_new(graph_t*graph, node_t*pos1, node_t*pos2)
157 {
158     graphcut_workspace_t*workspace = malloc(sizeof(graphcut_workspace_t));
159     workspace->flags1 = rfx_calloc(graph->num_nodes);
160     workspace->flags2 = rfx_calloc(graph->num_nodes);
161     workspace->back = rfx_calloc(graph->num_nodes*sizeof(halfedge_t*));
162     workspace->pos1 = pos1;
163     workspace->pos2 = pos2;
164     workspace->graph = graph;
165     workspace->queue1 = posqueue_new();
166     workspace->queue2 = posqueue_new();
167     workspace->tmpqueue = posqueue_new();
168     return workspace;
169 }
170 static void graphcut_workspace_delete(graphcut_workspace_t*w) 
171 {
172     posqueue_delete(w->queue1);w->queue1=0;
173     posqueue_delete(w->queue2);w->queue2=0;
174     posqueue_delete(w->tmpqueue);w->tmpqueue=0;
175     if(w->flags1) free(w->flags1);w->flags1=0;
176     if(w->flags2) free(w->flags2);w->flags2=0;
177     if(w->back) free(w->back);w->back=0;
178     free(w);
179 }
180
181 typedef struct _path {
182     node_t**pos;
183     halfedge_t**dir;
184     unsigned char*firsthalf;
185     int length;
186 } path_t;
187
188 static path_t*path_new(int len)
189 {
190     path_t*p = malloc(sizeof(path_t));
191     p->pos = malloc(sizeof(node_t*)*len);
192     p->dir = malloc(sizeof(halfedge_t*)*len);
193     p->firsthalf = malloc(sizeof(unsigned char)*len);
194     p->length = len;
195     return p;
196 }
197 static void path_delete(path_t*path)
198 {
199     free(path->pos);path->pos = 0;
200     free(path->dir);path->dir = 0;
201     free(path->firsthalf);path->firsthalf = 0;
202     free(path);
203 }
204
205 static path_t*extract_path(graphcut_workspace_t*work, unsigned char*mytree, unsigned char*othertree, node_t*pos, node_t*newpos, halfedge_t*dir)
206 {
207     int t;
208     node_t*p = pos;
209     node_t*nodes = work->graph->nodes;
210     int len1 = 0;
211     /* walk up tree1 */
212     DBG printf("walk back up (1) to %d\n", NR(work->pos1));
213     while(p != work->pos1) {
214         halfedge_t*back = work->back[NR(p)];
215         DBG printf("walk backward (1): %d %d\n", NR(p), back?NR(back->fwd->node):-1);
216         node_t*old = p;
217         p = work->back[NR(p)]->fwd->node;
218         assert(p!=old);
219         len1++;
220     }
221     p = newpos;
222     int len2 = 0;
223     DBG printf("walk back up (2) to %d\n", NR(work->pos2));
224     /* walk up tree2 */
225     while(p != work->pos2) {
226         DBG printf("walk backward (2): %d\n", NR(p));
227         p = work->back[NR(p)]->fwd->node;
228         len2++;
229     }
230     path_t*path = path_new(len1+len2+2);
231
232     t = len1;
233     path->pos[t] = p = pos;
234     path->dir[t] = dir;
235     path->firsthalf[t] = 1;
236     while(p != work->pos1) {
237         assert(mytree[NR(p)]&IN_TREE);
238         halfedge_t*dir = work->back[NR(p)];
239         assert(dir->node == p);
240         p = dir->fwd->node;
241         t--;
242         path->pos[t] = p;
243         path->dir[t] = dir->fwd;
244         path->firsthalf[t] = 1;
245     }
246     assert(!t);
247
248     t = len1+1;
249
250     p = newpos;
251     while(p != work->pos2) {
252         assert(othertree[NR(p)]&IN_TREE);
253         halfedge_t*dir = work->back[NR(p)];
254         path->pos[t] = p;
255         path->dir[t] = dir;
256         path->firsthalf[t] = 0;
257         p = dir->fwd->node;
258         t++;
259     }
260
261     /* terminator */
262     path->pos[t] = p;
263     path->dir[t] = 0; // last node
264     path->firsthalf[t] = 0;
265
266     assert(t == len1+len2+1);
267     return path;
268 }
269
270 static void path_print(path_t*path)
271 {
272     int t;
273     for(t=0;t<path->length;t++) {
274         node_t*n = path->pos[t];
275         printf("%d (firsthalf: %d)", NR(n), path->firsthalf[t]);
276         if(t<path->length-1) {
277             printf(" -(%d/%d)-> \n", 
278                     path->dir[t]->used,
279                     path->dir[t]->fwd->used);
280         } else {
281             printf("\n");
282         }
283     }
284
285     for(t=0;t<path->length-1;t++) {
286         if(path->firsthalf[t]==path->firsthalf[t+1]) {
287             assert(( path->firsthalf[t] && path->dir[t]->used) || 
288                    (!path->firsthalf[t] && path->dir[t]->fwd->used));
289         }
290     }
291     printf("\n");
292 }
293
294
295 static void workspace_print(graphcut_workspace_t*w)
296 {
297     printf("queue1: ");posqueue_print(w, w->queue1);
298     printf("queue2: ");posqueue_print(w, w->queue2);
299 }
300
301 static void myassert(graphcut_workspace_t*w, char assertion, const char*file, int line, const char*func)
302 {
303     if(!assertion) {
304         printf("Assertion %s:%d (%s) failed:\n", file, line, func);
305         workspace_print(w);
306         exit(0);
307     }
308 }
309
310 #define ASSERT(w,c) {myassert(w,c,__FILE__,__LINE__,__func__);}
311
312 static path_t* expand_pos(graphcut_workspace_t*w, posqueue_t*queue, node_t*pos, char reverse, unsigned char*mytree, unsigned char*othertree)
313 {
314     graph_t*graph = w->graph;
315     int dir;
316     if((mytree[NR(pos)]&(IN_TREE|ACTIVE)) != (IN_TREE|ACTIVE)) {
317         /* this node got deleted or marked inactive in the meantime. ignore it */
318         DBG printf("node %d is deleted or inactive\n", NR(pos));
319         return 0;
320     }
321
322     halfedge_t*e = pos->edges;
323     for(;e;e=e->next) {
324         node_t*newpos = e->fwd->node;
325         weight_t weight = reverse?e->fwd->weight:e->weight;
326         if(mytree[NR(newpos)]) continue; // already known
327
328         if(weight) {
329             if(othertree[NR(newpos)]) {
330                 DBG printf("found connection: %d connects to %d\n", NR(pos), NR(newpos));
331                 posqueue_addpos(queue, pos); mytree[NR(pos)] |= ACTIVE; // re-add, this vertex might have other connections
332
333                 path_t*path;
334                 if(reverse) {
335                     path = extract_path(w, othertree, mytree, newpos, pos, e->fwd);
336                 } else {
337                     path = extract_path(w, mytree, othertree, pos, newpos, e);
338                 }
339                 return path;
340             } else {
341                 DBG printf("advance from %d to new pos %d\n", NR(pos), NR(newpos));
342                 w->back[NR(newpos)] = e->fwd;
343                 e->used = 1;
344                 posqueue_addpos(queue, newpos); mytree[NR(newpos)] |= ACTIVE|IN_TREE; // add
345             }
346         }
347     }
348     /* if we can't expand this node anymore, it's now an inactive node */
349     mytree[NR(pos)] &= ~ACTIVE;
350     return 0;
351 }
352
353 static int node_count_edges(node_t*node)
354 {
355     halfedge_t*e = node->edges;
356     int num = 0;
357     while(e) {
358         num++;
359         e = e->next;
360     }
361     return num;
362 }
363
364 static void bool_op(graphcut_workspace_t*w, unsigned char*flags, node_t*pos, unsigned char and, unsigned char or)
365 {
366     posqueue_t*q = w->tmpqueue;
367     posqueue_purge(q);
368     posqueue_addpos(q, pos);
369
370     while(posqueue_notempty(q)) {
371         node_t*p = posqueue_extract(q);
372         flags[NR(p)] = (flags[NR(p)]&and)|or;
373         halfedge_t*e = p->edges;
374         while(e) {
375             if(e->used) {
376                 posqueue_addpos(q, e->fwd->node);
377             }
378             e = e->next;
379         }
380     }
381 }
382
383 static weight_t decrease_weights(graph_t*map, path_t*path)
384 {
385     int t;
386     assert(path->length);
387
388     weight_t min = path->dir[0]->weight;
389     for(t=0;t<path->length-1;t++) {
390         int w = path->dir[t]->weight;
391         DBG printf("%d->%d (%d)\n", NR(path->dir[t]->node), NR(path->dir[t]->fwd->node), w);
392         if(t==0 || w < min) min = w;
393     }
394     assert(min);
395     if(min<=0) 
396         return 0;
397
398     for(t=0;t<path->length-1;t++) {
399         path->dir[t]->weight-=min;
400         path->dir[t]->fwd->weight+=min;
401     }
402     return min;
403 }
404
405 static int reconnect(graphcut_workspace_t*w, unsigned char*flags, node_t*pos, char reverse)
406 {
407     graph_t*graph = w->graph;
408
409     halfedge_t*e = pos->edges;
410     for(;e;e=e->next) {
411         node_t*newpos = e->fwd->node;
412         int weight;
413         if(!reverse) {
414             weight = e->fwd->weight;
415         } else {
416             weight = e->weight;
417         }
418         if(weight && (flags[NR(newpos)]&IN_TREE)) {
419             DBG printf("successfully reconnected node %d to %d (%d->%d) (reverse:%d)\n", 
420                     NR(pos), NR(newpos), NR(e->node), NR(e->fwd->node), reverse);
421
422             w->back[NR(pos)] = e;
423             e->fwd->used = 1;
424             return 1;
425         }
426     }
427     return 0;
428 }
429
430 static void clear_node(graphcut_workspace_t*w, node_t*n)
431 {
432     w->flags1[NR(n)] = 0;
433     w->flags2[NR(n)] = 0;
434     w->back[NR(n)] = 0;
435     halfedge_t*e = n->edges;
436     while(e) {e->used = 0;e=e->next;}
437 }
438
439 static void destroy_subtree(graphcut_workspace_t*w, unsigned char*flags, node_t*pos, posqueue_t*posqueue)
440 {
441     DBG printf("destroying subtree starting with %d\n", NR(pos));
442
443     posqueue_t*q = w->tmpqueue;
444     posqueue_purge(q);
445     posqueue_addpos(q, pos);
446
447     while(posqueue_notempty(q)) {
448         node_t*p = posqueue_extract(q);
449         halfedge_t*e = p->edges;
450         while(e) {
451             node_t*newpos = e->fwd->node;       
452             if(e->used) {
453                 posqueue_addpos(q, newpos);
454             } else if((flags[NR(newpos)]&(ACTIVE|IN_TREE)) == IN_TREE) {
455                 // re-activate all nodes that surround our subtree.
456                 // TODO: we should check the weight of the edge from that other
457                 // node to our node. if it's zero, we don't need to activate that node.
458                 posqueue_addpos(posqueue, newpos);
459                 flags[NR(newpos)]|=ACTIVE;
460             }
461             e = e->next;
462         }
463        
464         clear_node(w, p);
465         DBG printf("removed pos %d\n", NR(p));
466     }
467 }
468
469 static void combust_tree(graphcut_workspace_t*w, posqueue_t*q1, posqueue_t*q2, path_t*path)
470 {
471     graph_t*graph = w->graph;
472     int t;
473     for(t=0;t<path->length-1 && path->firsthalf[t+1];t++) {
474         node_t*pos = path->pos[t];
475         halfedge_t*dir = path->dir[t];
476         node_t*newpos = dir->fwd->node;
477         if(!dir->weight) {
478             /* disconnect node */
479             DBG printf("remove link %d -> %d from tree 1\n", NR(pos), NR(newpos));
480            
481             dir->used = 0;
482             w->flags1[NR(newpos)] &= ACTIVE;
483             bool_op(w, w->flags1, newpos, ~IN_TREE, 0);
484
485             /* try to reconnect the path to some other tree part */
486             if(reconnect(w, w->flags1, newpos, 0)) {
487                 bool_op(w, w->flags1, newpos, ~0, IN_TREE);
488             } else {
489                 destroy_subtree(w, w->flags1, newpos, q1);
490                 break;
491             }
492         }
493     }
494
495     for(t=path->length-1;t>0 && !path->firsthalf[t-1];t--) {
496         node_t*pos = path->pos[t];
497         node_t*newpos = path->pos[t-1];
498         halfedge_t*dir = path->dir[t-1]->fwd;
499         node_t*newpos2 = dir->fwd->node;
500         assert(newpos == newpos2);
501         if(!dir->fwd->weight) {
502             /* disconnect node */
503             DBG printf("remove link %d->%d from tree 2\n", NR(pos), NR(newpos));
504
505             dir->used = 0;
506             w->flags2[NR(newpos)] &= ACTIVE;
507             bool_op(w, w->flags2, newpos, ~IN_TREE, 0);
508
509             /* try to reconnect the path to some other tree part */
510             if(reconnect(w, w->flags2, newpos, 1)) {
511                 bool_op(w, w->flags2, newpos, ~0, IN_TREE);
512             } else {
513                 destroy_subtree(w, w->flags2, newpos, q2);
514                 break;
515             }
516         }
517     }
518 }
519
520 static void check_graph(graph_t*g)
521 {
522     int t;
523     for(t=0;t<g->num_nodes;t++) {
524         assert(g->nodes[t].nr==t);
525         halfedge_t*e = g->nodes[t].edges;
526         while(e) {
527             assert(!e->used || !e->fwd->used);
528             e = e->next;
529         }
530     }
531 }
532
533 void graph_reset(graph_t*g)
534 {
535     int t;
536     for(t=0;t<g->num_nodes;t++) {
537         g->nodes[t].nr = t;
538         assert(g->nodes[t].nr==t);
539         halfedge_t*e = g->nodes[t].edges;
540         while(e) {
541             e->used = 0;
542             e->weight = e->init_weight;
543             e = e->next;
544         }
545     }
546 }
547
548 weight_t graph_maxflow(graph_t*graph, node_t*pos1, node_t*pos2)
549 {
550     int max_flow = 0;
551     graphcut_workspace_t* w = graphcut_workspace_new(graph, pos1, pos2);
552
553     graph_reset(graph);
554     DBG check_graph(graph);
555    
556     posqueue_addpos(w->queue1, pos1); w->flags1[pos1->nr] |= ACTIVE|IN_TREE; 
557     posqueue_addpos(w->queue2, pos2); w->flags2[pos2->nr] |= ACTIVE|IN_TREE; 
558     DBG workspace_print(w);
559   
560     while(1) {
561         path_t*path;
562         while(1) {
563             char done1=0,done2=0;
564             node_t* p1 = posqueue_extract(w->queue1);
565             if(!p1) {
566                 graphcut_workspace_delete(w);
567                 return max_flow;
568             }
569             DBG printf("extend 1 from %d (%d edges)\n", NR(p1), node_count_edges(p1));
570             path = expand_pos(w, w->queue1, p1, 0, w->flags1, w->flags2);
571             if(path)
572                 break;
573             DBG workspace_print(w);
574            
575 #ifdef TWOTREES
576             node_t* p2 = posqueue_extract(w->queue2);
577             if(!p2) {
578                 graphcut_workspace_delete(w);
579                 return max_flow;
580             }
581             DBG printf("extend 2 from %d (%d edges)\n", NR(p2), node_count_edges(p2));
582             path = expand_pos(w, w->queue2, p2, 1, w->flags2, w->flags1);
583             if(path)
584                 break;
585             DBG workspace_print(w);
586 #endif
587
588         }
589         DBG printf("found connection between tree1 and tree2\n");
590         DBG path_print(path);
591
592         DBG printf("decreasing weights\n");
593         max_flow += decrease_weights(graph, path);
594         DBG workspace_print(w);
595
596         DBG printf("destroying trees\n");
597         combust_tree(w, w->queue1, w->queue2, path);
598         DBG workspace_print(w);
599
600         DBG check_graph(w->graph);
601
602         path_delete(path);
603     }
604     graphcut_workspace_delete(w);
605     return max_flow;
606 }
607
608 halfedge_t*graph_add_edge(node_t*from, node_t*to, weight_t forward_weight, weight_t backward_weight)
609 {
610     halfedge_t*e1 = (halfedge_t*)rfx_calloc(sizeof(halfedge_t));
611     halfedge_t*e2 = (halfedge_t*)rfx_calloc(sizeof(halfedge_t));
612     e1->fwd = e2;
613     e2->fwd = e1;
614     e1->node = from;
615     e2->node = to;
616     e1->init_weight = forward_weight;
617     e2->init_weight = backward_weight;
618     e1->weight = forward_weight;
619     e2->weight = backward_weight;
620
621     e1->next = from->edges;
622     from->edges = e1;
623     e2->next = to->edges;
624     to->edges = e2;
625     return e1;
626 }
627
628 static void do_dfs(node_t*n, int color)
629 {
630     int t;
631     n->tmp = color;
632     halfedge_t*e = n->edges;
633     while(e) {
634         if(e->fwd->node->tmp<0)
635             do_dfs(e->fwd->node, color);
636         e = e->next;
637     }
638 }
639
640 int graph_find_components(graph_t*g)
641 {
642     int t;
643     int count = 0;
644     for(t=0;t<g->num_nodes;t++) {
645         g->nodes[t].tmp = -1;
646     }
647     for(t=0;t<g->num_nodes;t++) {
648         if(g->nodes[t].tmp<0) {
649             do_dfs(&g->nodes[t], count++);
650         }
651     }
652     return count;
653 }
654
655 #ifdef MAIN
656 int main()
657 {
658     int t;
659     int s;
660     for(s=0;s<10;s++) {
661         int width = (lrand48()%8)+1;
662         graph_t*g = graph_new(width*width);
663         for(t=0;t<width*width;t++) {
664             int x = t%width;
665             int y = t/width;
666             int w = 1;
667 #define R (lrand48()%32)
668             if(x>0) graph_add_edge(&g->nodes[t], &g->nodes[t-1], R, R);
669             if(x<width-1) graph_add_edge(&g->nodes[t], &g->nodes[t+1], R, R);
670             if(y>0) graph_add_edge(&g->nodes[t], &g->nodes[t-width], R, R);
671             if(y<width-1) graph_add_edge(&g->nodes[t], &g->nodes[t+width], R, R);
672         }
673         
674         int x = graph_maxflow(g, &g->nodes[0], &g->nodes[width*width-1]);
675         printf("max flow: %d\n", x);
676         graph_delete(g);
677     }
678 }
679 #endif