60b8ed10294aa1e83b399cb3605a3ef9750b597b
[libs/core.git] / source / strings / regex.cpp
1 #include <limits>
2 #include <list>
3 #include <stack>
4 #include "format.h"
5 #include "regex.h"
6
7 using namespace std;
8
9 namespace {
10
11 /** Writes an integer to a Regex code string, in little-endian order. */
12 template<typename T>
13 void write_int(T n, Msp::Regex::Code &code)
14 {
15         for(unsigned i=0; i<sizeof(T); ++i)
16                 code += (n>>(i*8))&0xFF;
17 }
18
19 /** Reads an integer from a Regex code string, in little-endian order. */
20 template<typename T>
21 T read_int(Msp::Regex::Code::const_iterator &c)
22 {
23         T result = 0;
24         for(unsigned i=0; i<sizeof(T); ++i)
25                 result += (*c++)<<(i*8);
26         return result;
27 }
28
29 }
30
31
32 namespace Msp {
33
34 bad_regex::bad_regex(const string &w, const string &e, const string::const_iterator &i):
35         logic_error(w+"\n"+make_where(e, i))
36 { }
37
38 string bad_regex::make_where(const string &e, const string::const_iterator &i)
39 {
40         string result;
41         string::size_type offset = i-e.begin();
42         if(offset>40)
43         {
44                 result = e.substr(offset-40, 60);
45                 offset = 40;
46         }
47         else
48                 result = e.substr(0, 60);
49         result += '\n';
50         result.append(offset, ' ');
51         result += '^';
52         return result;
53 }
54
55
56 Regex::Regex(const string &expr)
57 {
58         n_groups = 0;
59         string::const_iterator iter = expr.begin();
60         code = compile(expr, iter, n_groups, false);
61         ++n_groups;
62 }
63
64 Regex::Code Regex::compile(const string &expr, string::const_iterator &iter, unsigned &group, bool branch)
65 {
66         bool has_branches = false;
67         stack<string::const_iterator> parens;
68         bool escape = false;
69         unsigned bracket = 0;
70         string::const_iterator end;
71         for(end=iter; end!=expr.end(); ++end)
72         {
73                 if(escape)
74                         escape = false;
75                 else if(bracket)
76                 {
77                         if(bracket==3 && *end==']')
78                                 bracket = 0;
79                         else if(bracket==1 && *end=='^')
80                                 bracket = 2;
81                         else
82                                 bracket = 3;
83                 }
84                 else if(*end=='\\')
85                         escape = true;
86                 else if(*end=='(')
87                         parens.push(end);
88                 else if(*end==')')
89                 {
90                         if(parens.empty())
91                         {
92                                 if(group==0)
93                                         throw bad_regex("unmatched ')'", expr, end);
94                                 else
95                                         break;
96                         }
97                         parens.pop();
98                 }
99                 else if(*end=='|' && parens.empty())
100                 {
101                         if(branch)
102                                 break;
103                         else
104                                 has_branches = true;
105                 }
106                 else if(*end=='[')
107                         bracket = 1;
108         }
109
110         if(!parens.empty())
111                 throw bad_regex("unmatched '('", expr, parens.top());
112
113         Code result;
114
115         unsigned this_group = group;
116         if(!branch)
117         {
118                 result += GROUP_BEGIN;
119                 write_int<Index>(this_group, result);
120         }
121
122         const unsigned jump_size = 1+sizeof(Offset);
123
124         if(!has_branches)
125         {
126                 for(string::const_iterator i=iter; i!=end;)
127                 {
128                         Code atom = parse_atom(expr, i, group);
129
130                         Count repeat_min = 1;
131                         Count repeat_max = 1;
132                         parse_repeat(expr, i, repeat_min, repeat_max);
133
134                         for(unsigned j=0; j<repeat_min; ++j)
135                                 result += atom;
136                         if(repeat_max==numeric_limits<Count>::max())
137                         {
138                                 if(repeat_min==0)
139                                 {
140                                         result += ND_JUMP;
141                                         write_int<Offset>(atom.size()+jump_size, result);
142                                         result += atom;
143                                 }
144                                 result += ND_JUMP;
145                                 write_int<Offset>(-(atom.size()+jump_size), result);
146                         }
147                         else if(repeat_max>repeat_min)
148                         {
149                                 for(unsigned j=repeat_min; j<repeat_max; ++j)
150                                 {
151                                         result += ND_JUMP;
152                                         write_int<Offset>((repeat_max-j)*(atom.size()+jump_size)-jump_size, result);
153                                         result += atom;
154                                 }
155                         }
156                 }
157         }
158         else
159         {
160                 list<Code> branches;
161                 for(string::const_iterator i=iter;;)
162                 {
163                         branches.push_back(compile(expr, i, group, true));
164                         if(i==end)
165                                 break;
166                         ++i;
167                 }
168
169                 unsigned n_branches = branches.size();
170
171                 Offset offset = (n_branches-1)*jump_size+branches.front().size();
172                 for(list<Code>::iterator i=++branches.begin(); i!=branches.end(); ++i)
173                 {
174                         result += ND_JUMP;
175                         write_int<Offset>(offset, result);
176                         offset += i->size();
177                 }
178
179                 for(list<Code>::iterator i=branches.begin(); i!=branches.end();)
180                 {
181                         result += *i;
182                         offset -= i->size()+jump_size;
183                         ++i;
184                         if(i!=branches.end())
185                         {
186                                 result += JUMP;
187                                 write_int<Offset>(offset, result);
188                         }
189                 }
190         }
191
192         if(!branch)
193         {
194                 result += GROUP_END;
195                 write_int<Index>(this_group, result);
196         }
197
198         iter = end;
199
200         return result;
201 }
202
203 Regex::Code Regex::parse_atom(const string &expr, string::const_iterator &i, unsigned &group)
204 {
205         Code result;
206
207         if(i==expr.end())
208                 return result;
209
210         bool flag = false;
211         if(*i=='\\')
212         {
213                 if(++i==expr.end())
214                         throw bad_regex("stray backslash", expr, i-1);
215                 flag = true;
216         }
217
218         if(!flag)
219         {
220                 if(*i=='*' || *i=='{' || *i=='}' || *i=='+' || *i=='?' || *i=='|' || *i==')')
221                         throw bad_regex("invalid atom", expr, i);
222                 else if(*i=='[')
223                         return parse_brackets(expr, i);
224                 else if(*i=='.')
225                         result += MATCH_ANY;
226                 else if(*i=='^')
227                         result += MATCH_BEGIN;
228                 else if(*i=='$')
229                         result += MATCH_END;
230                 else if(*i=='(')
231                 {
232                         ++group;
233                         result = compile(expr, ++i, group, false);
234                 }
235                 else
236                         flag = true;
237         }
238
239         if(flag)
240         {
241                 result += MATCH_CHAR;
242                 result += *i;
243         }
244
245         ++i;
246
247         return result;
248 }
249
250 bool Regex::parse_repeat(const string &expr, string::const_iterator &i, Count &rmin, Count &rmax)
251 {
252         if(*i!='*' && *i!='+' && *i!='?' && *i!='{')
253                 return false;
254
255         if(*i=='*' || *i=='+')
256                 rmax = numeric_limits<Count>::max();
257         if(*i=='*' || *i=='?')
258                 rmin = 0;
259         if(*i=='{')
260         {
261                 string::const_iterator begin = i;
262
263                 rmin = 0;
264                 for(++i; isdigit(*i); ++i)
265                         rmin = rmin*10+(*i-'0');
266
267                 if(*i==',')
268                 {
269                         ++i;
270                         if(*i!='}')
271                         {
272                                 rmax = 0;
273                                 for(; isdigit(*i); ++i)
274                                         rmax = rmax*10+(*i-'0');
275                                 if(rmax<rmin)
276                                         throw bad_regex("invalid bound", expr, begin);
277                         }
278                         else
279                                 rmax = numeric_limits<Count>::max();
280                 }
281                 else
282                         rmax = rmin;
283                 if(*i!='}')
284                         throw bad_regex("invalid bound", expr, begin);
285         }
286
287         ++i;
288
289         return true;
290 }
291
292 Regex::Code Regex::parse_brackets(const string &str, string::const_iterator &iter)
293 {
294         string::const_iterator begin = iter;
295         Code result;
296
297         ++iter;
298         bool neg = false;
299         if(*iter=='^')
300         {
301                 neg = true;
302                 ++iter;
303         }
304
305         string::const_iterator end = iter;
306         for(; (end!=str.end() && (end==iter || *end!=']')); ++end) ;
307         if(end==str.end())
308                 throw bad_regex("unmatched '['", str, begin);
309
310         unsigned char mask[32] = {0};
311         unsigned type = 0;
312         bool range = false;
313         unsigned char first=0, last = 0;
314         for(string::const_iterator i=iter; i!=end; ++i)
315         {
316                 unsigned char c = *i;
317                 if(range)
318                 {
319                         last = c;
320                         for(unsigned j=first; j<=c; ++j)
321                                 mask[j>>3] |= 1<<(j&7);
322                         range = false;
323                         if(type<2)
324                                 type = 2;
325                 }
326                 else if(c=='-' && i!=iter && end-i>1)
327                         range = true;
328                 else
329                 {
330                         first = c;
331                         mask[c>>3] |= 1<<(c&7);
332                         if(type==0)
333                                 type = 1;
334                         else
335                                 type = 3;
336                 }
337         }
338
339         if(neg)
340                 result += NEGATE;
341
342         if(type==1)
343         {
344                 result += MATCH_CHAR;
345                 result += first;
346         }
347         else if(type==2)
348         {
349                 result += MATCH_RANGE;
350                 result += first;
351                 result += last;
352         }
353         else
354         {
355                 result += MATCH_MASK;
356                 result.append(mask, 32);
357         }
358
359         iter = end;
360         ++iter;
361
362         return result;
363 }
364
365 RegMatch Regex::match(const string &str) const
366 {
367         RegMatch::GroupArray groups(n_groups);
368
369         for(string::const_iterator i=str.begin(); i!=str.end(); ++i)
370                 if(run(str, i, groups))
371                         return RegMatch(str, groups);
372
373         return RegMatch();
374 }
375
376 bool Regex::run(const string &str, const string::const_iterator &begin, RegMatch::GroupArray &groups) const
377 {
378         bool result = false;
379         list<RunContext> ctx;
380         ctx.push_back(RunContext());
381         ctx.front().citer = code.begin();
382         ctx.front().groups.resize(groups.size());
383
384         for(string::const_iterator i=begin;;)
385         {
386                 int c;
387                 if(i!=str.end())
388                         c = *i;
389                 else
390                         c = -1;
391
392                 for(list<RunContext>::iterator j=ctx.begin(); j!=ctx.end();)
393                 {
394                         bool terminate = false;
395                         bool negate_match = false;
396                         for(; j->citer!=code.end();)
397                         {
398                                 Instruction instr = static_cast<Instruction>(*j->citer++);
399
400                                 if(instr==NEGATE)
401                                         negate_match = true;
402                                 else if(instr==JUMP)
403                                 {
404                                         Offset offset = read_int<Offset>(j->citer);
405                                         j->citer += offset;
406                                 }
407                                 else if(instr==ND_JUMP)
408                                 {
409                                         Offset offset = read_int<Offset>(j->citer);
410                                         ctx.push_back(*j);
411                                         ctx.back().citer += offset;
412                                 }
413                                 else if(instr==GROUP_BEGIN)
414                                 {
415                                         Index n = read_int<Index>(j->citer);
416                                         if(!j->groups[n].match)
417                                                 j->groups[n].begin = i-str.begin();
418                                 }
419                                 else if(instr==GROUP_END)
420                                 {
421                                         Index n = read_int<Index>(j->citer);
422                                         if(!j->groups[n].match)
423                                         {
424                                                 j->groups[n].match = true;
425                                                 j->groups[n].end = i-str.begin();
426                                                 j->groups[n].length = j->groups[n].end-j->groups[n].begin;
427                                         }
428
429                                         if(n==0)
430                                         {
431                                                 result = true;
432                                                 bool better = false;
433                                                 for(unsigned k=0; (k<groups.size() && !better); ++k)
434                                                 {
435                                                         better = group_compare(j->groups[k], groups[k]);
436                                                         if(group_compare(groups[k], j->groups[k]))
437                                                                 break;
438                                                 }
439                                                 if(better)
440                                                         groups = j->groups;
441                                         }
442                                 }
443                                 else
444                                 {
445                                         bool match_result = false;
446                                         bool input_consumed = false;
447                                         if(instr==MATCH_BEGIN)
448                                                 match_result = (i==str.begin());
449                                         else if(instr==MATCH_END)
450                                                 match_result = (i==str.end());
451                                         else if(instr==MATCH_CHAR)
452                                         {
453                                                 match_result = (c==*j->citer++);
454                                                 input_consumed = true;
455                                         }
456                                         else if(instr==MATCH_RANGE)
457                                         {
458                                                 unsigned char first = *j->citer++;
459                                                 unsigned char last = *j->citer++;
460                                                 match_result = (c>=first && c<=last);
461                                                 input_consumed = true;
462                                         }
463                                         else if(instr==MATCH_MASK)
464                                         {
465                                                 if(c>=0 && c<=0xFF)
466                                                 {
467                                                         unsigned char m = *(j->citer+(c>>3));
468                                                         match_result = m&(1<<(c&7));
469                                                 }
470                                                 input_consumed = true;
471                                                 j->citer += 32;
472                                         }
473                                         else if(instr==MATCH_ANY)
474                                         {
475                                                 match_result = true;
476                                                 input_consumed = true;
477                                         }
478                                         else
479                                                 throw logic_error("invalid instruction in regex bytecode");
480
481                                         if(match_result==negate_match)
482                                                 terminate = true;
483                                         negate_match = false;
484
485                                         if(input_consumed || terminate)
486                                                 break;
487                                 }
488                         }
489
490                         if(terminate || j->citer==code.end())
491                                 j = ctx.erase(j);
492                         else
493                                 ++j;
494                 }
495
496                 if(i==str.end() || ctx.empty())
497                         break;
498                 ++i;
499         }
500
501         return result;
502 }
503
504 bool Regex::group_compare(const RegMatch::Group &g1, const RegMatch::Group &g2) const
505 {
506         if(!g1.match)
507                 return false;
508
509         // Any match is better than no match
510         if(!g2.match)
511                 return true;
512
513         // Earlier match is better
514         if(g1.begin<g2.begin)
515                 return true;
516         if(g2.begin>g2.begin)
517                 return false;
518
519         // Longer match at same position is better
520         return g1.end>g2.end;
521 }
522
523 string Regex::disassemble() const
524 {
525         string result;
526
527         for(Code::const_iterator i=code.begin(); i!=code.end();)
528         {
529                 Code::const_iterator j = i;
530                 Offset offset = i-code.begin();
531                 string decompiled = disassemble_instruction(i);
532                 string bytes;
533                 for(; j!=i; ++j)
534                         bytes += format(" %02X", *j);
535                 result += format("%3d:%-9s  ", offset, bytes);
536                 if(bytes.size()>9)
537                         result += "\n               ";
538                 result += decompiled;
539                 result += '\n';
540         }
541
542         return result;
543 }
544
545 string Regex::disassemble_instruction(Code::const_iterator &i) const
546 {
547         Instruction instr = static_cast<Instruction>(*i++);
548
549         switch(instr)
550         {
551         case JUMP:
552                 {
553                         Offset offset = read_int<Offset>(i);
554                         return format("JUMP %+d (%d)", offset, (i-code.begin())+offset);
555                 }
556         case ND_JUMP:
557                 {
558                         Offset offset = read_int<Offset>(i);
559                         return format("ND_JUMP %+d (%d)", offset, (i-code.begin())+offset);
560                 }
561         case GROUP_BEGIN:
562                 return format("GROUP_BEGIN %d", read_int<Index>(i));
563         case GROUP_END:
564                 return format("GROUP_END %d", read_int<Index>(i));
565         case NEGATE:
566                 return "NEGATE";
567         case MATCH_BEGIN:
568                 return "MATCH_BEGIN";
569         case MATCH_END:
570                 return "MATCH_END";
571         case MATCH_CHAR:
572                 {
573                         unsigned char c = *i++;
574                         if(c>=0x20 && c<=0x7E)
575                                 return format("MATCH_CHAR '%c'", c);
576                         else
577                                 return format("MATCH_CHAR %d", c);
578                 }
579                 break;
580         case MATCH_RANGE:
581                 {
582                         int begin = *i++;
583                         int end = *i++;
584                         return format("MATCH_RANGE %d-%d", begin, end);
585                 }
586         case MATCH_MASK:
587                 {
588                         string result = "MATCH_MASK";
589                         for(unsigned j=0; j<32; ++j)
590                                 result += format(" %02X", *i++);
591                         return result;
592                 }
593         case MATCH_ANY:
594                 return "MATCH_ANY";
595         default:
596                 return format("UNKNOWN %d", instr);
597         }
598 }
599
600 } // namespace Msp