]> git.tdb.fi Git - libs/core.git/blob - source/strings/regex.cpp
Rename UnicodeChar to unichar
[libs/core.git] / source / strings / regex.cpp
1 #include <stack>
2 #include <limits>
3 #include <msp/core/except.h>
4 #include "format.h"
5 #include "regex.h"
6
7 using namespace std;
8
9 namespace {
10
11 /** Writes an integer to a Regex code string, in little-endian order. */
12 template<typename T>
13 void write_int(T n, Msp::Regex::Code &code)
14 {
15         for(unsigned i=0; i<sizeof(T); ++i)
16                 code += (n>>i*8)&0xFF;
17 }
18
19 /** Reads an integer from a Regex code stream, in little-endian order. */
20 template<typename T>
21 T read_int(Msp::Regex::Code::const_iterator &c)
22 {
23         T result = 0;
24         for(unsigned i=0; i<sizeof(T); ++i)
25                 result += static_cast<unsigned char>(*c++)<<i*8;
26         return result;
27 }
28
29 }
30
31
32 namespace Msp {
33
34 Regex::Regex(const string &expr)
35 {
36         n_groups = 0;
37         string::const_iterator iter = expr.begin();
38         code = compile(expr, iter, n_groups, false);
39         ++n_groups;
40 }
41
42 Regex::Code Regex::compile(const string &expr, string::const_iterator &iter, unsigned &group, bool branch)
43 {
44         bool has_branches = false;
45         unsigned level = 0;
46         bool escape = false;
47         unsigned bracket = 0;
48         string::const_iterator end;
49         for(end=iter; end!=expr.end(); ++end)
50         {
51                 if(escape)
52                         escape = false;
53                 else if(bracket)
54                 {
55                         if(bracket==3 && *end==']')
56                                 bracket = 0;
57                         else if(bracket==1 && *end=='^')
58                                 bracket = 2;
59                         else
60                                 bracket = 3;
61                 }
62                 else if(*end=='\\')
63                         escape = true;
64                 else if(*end=='(')
65                         ++level;
66                 else if(*end==')')
67                 {
68                         if(level==0)
69                         {
70                                 if(group==0)
71                                         throw InvalidParameterValue("Unexpected )");
72                                 else
73                                         break;
74                         }
75                         --level;
76                 }
77                 else if(*end=='|' && level==0)
78                 {
79                         if(branch)
80                                 break;
81                         else
82                                 has_branches = true;
83                 }
84                 else if(*end=='[')
85                         bracket = 1;
86         }
87
88         if(level>0)
89                 throw InvalidParameterValue("Unmatched (");
90
91         Code result;
92
93         unsigned this_group = group;
94         if(!branch)
95         {
96                 result += GROUP_BEGIN;
97                 write_int<Index>(this_group, result);
98         }
99
100         const unsigned jump_size = 1+sizeof(Offset);
101
102         if(!has_branches)
103         {
104                 for(string::const_iterator i=iter; i!=end;)
105                 {
106                         Code atom = parse_atom(expr, i, group);
107
108                         Count repeat_min = 1;
109                         Count repeat_max = 1;
110                         parse_repeat(i, repeat_min, repeat_max);
111
112                         for(unsigned j=0; j<repeat_min; ++j)
113                                 result += atom;
114                         if(repeat_max==numeric_limits<Count>::max())
115                         {
116                                 if(repeat_min==0)
117                                 {
118                                         result += ND_JUMP;
119                                         write_int<Offset>(atom.size()+jump_size, result);
120                                         result += atom;
121                                 }
122                                 result += ND_JUMP;
123                                 write_int<Offset>(-(atom.size()+jump_size), result);
124                         }
125                         else if(repeat_max>repeat_min)
126                         {
127                                 for(unsigned j=repeat_min; j<repeat_max; ++j)
128                                 {
129                                         result += ND_JUMP;
130                                         write_int<Offset>((repeat_max-j)*(atom.size()+jump_size)-jump_size, result);
131                                         result += atom;
132                                 }
133                         }
134                 }
135         }
136         else
137         {
138                 list<Code> branches;
139                 for(string::const_iterator i=iter;;)
140                 {
141                         branches.push_back(compile(expr, i, group, true));
142                         if(i==end)
143                                 break;
144                         ++i;
145                 }
146
147                 unsigned n_branches = branches.size();
148
149                 Offset offset = (n_branches-1)*jump_size+branches.front().size();
150                 for(list<Code>::iterator i=++branches.begin(); i!=branches.end(); ++i)
151                 {
152                         result += ND_JUMP;
153                         write_int<Offset>(offset, result);
154                         offset += i->size();
155                 }
156
157                 for(list<Code>::iterator i=branches.begin(); i!=branches.end();)
158                 {
159                         result += *i;
160                         offset -= i->size()+jump_size;
161                         ++i;
162                         if(i!=branches.end())
163                         {
164                                 result += JUMP;
165                                 write_int<Offset>(offset, result);
166                         }
167                 }
168         }
169
170         if(!branch)
171         {
172                 result += GROUP_END;
173                 write_int<Index>(this_group, result);
174         }
175
176         iter = end;
177
178         return result;
179 }
180
181 Regex::Code Regex::parse_atom(const string &expr, string::const_iterator &i, unsigned &group)
182 {
183         Code result;
184
185         if(i==expr.end())
186                 return result;
187
188         bool flag = false;
189         if(*i=='\\')
190         {
191                 if(++i==expr.end())
192                         throw InvalidParameterValue("Stray backslash");
193                 flag = true;
194         }
195
196         if(!flag)
197         {
198                 if(*i=='*' || *i=='{' || *i=='}' || *i=='+' || *i=='?' || *i=='|' || *i==')')
199                         throw InvalidParameterValue("Invalid atom");
200                 else if(*i=='[')
201                         return parse_brackets(expr, i);
202                 else if(*i=='.')
203                         result += MATCH_ANY;
204                 else if(*i=='^')
205                         result += MATCH_BEGIN;
206                 else if(*i=='$')
207                         result += MATCH_END;
208                 else if(*i=='(')
209                 {
210                         ++group;
211                         result = compile(expr, ++i, group, false);
212                 }
213                 else
214                         flag = true;
215         }
216
217         if(flag)
218         {
219                 result += MATCH_CHAR;
220                 result += *i;
221         }
222
223         ++i;
224
225         return result;
226 }
227
228 bool Regex::parse_repeat(string::const_iterator &i, Count &rmin, Count &rmax)
229 {
230         if(*i!='*' && *i!='+' && *i!='?' && *i!='{')
231                 return false;
232
233         if(*i=='*' || *i=='+')
234                 rmax = numeric_limits<Count>::max();
235         if(*i=='*' || *i=='?')
236                 rmin = 0;
237         if(*i=='{')
238         {
239                 rmin = 0;
240                 for(++i; isdigit(*i); ++i)
241                         rmin = rmin*10+(*i-'0');
242
243                 if(*i==',')
244                 {
245                         ++i;
246                         if(*i!='}')
247                         {
248                                 rmax = 0;
249                                 for(; isdigit(*i); ++i)
250                                         rmax = rmax*10+(*i-'0');
251                                 if(rmax<rmin)
252                                         throw InvalidParameterValue("Invalid bound");
253                         }
254                         else
255                                 rmax = numeric_limits<Count>::max();
256                 }
257                 else
258                         rmax = rmin;
259                 if(*i!='}')
260                         throw InvalidParameterValue("Invalid bound");
261         }
262
263         ++i;
264
265         return true;
266 }
267
268 Regex::Code Regex::parse_brackets(const string &str, string::const_iterator &iter)
269 {
270         Code result;
271
272         ++iter;
273         bool neg = false;
274         if(*iter=='^')
275         {
276                 neg = true;
277                 ++iter;
278         }
279
280         string::const_iterator end = iter;
281         for(; (end!=str.end() && (end==iter || *end!=']')); ++end) ;
282         if(end==str.end())
283                 throw InvalidParameterValue("Unmatched '['");
284
285         unsigned char mask[32] = {0};
286         unsigned type = 0;
287         bool range = false;
288         unsigned char first=0, last = 0;
289         for(string::const_iterator i=iter; i!=end; ++i)
290         {
291                 unsigned char c = *i;
292                 if(range)
293                 {
294                         last = c;
295                         for(unsigned j=first; j<=c; ++j)
296                                 mask[j>>3] |= 1<<(j&7);
297                         range = false;
298                         if(type<2)
299                                 type = 2;
300                 }
301                 else if(c=='-' && i!=iter && end-i>1)
302                         range = true;
303                 else
304                 {
305                         first = c;
306                         mask[c>>3] |= 1<<(c&7);
307                         if(type==0)
308                                 type = 1;
309                         else
310                                 type = 3;
311                 }
312         }
313
314         if(neg)
315                 result += NEGATE;
316
317         if(type==1)
318         {
319                 result += MATCH_CHAR;
320                 result += first;
321         }
322         else if(type==2)
323         {
324                 result += MATCH_RANGE;
325                 result += first;
326                 result += last;
327         }
328         else
329         {
330                 result += MATCH_MASK;
331                 result.append(reinterpret_cast<char *>(mask), 32);
332         }
333
334         iter = end;
335         ++iter;
336
337         return result;
338 }
339
340 RegMatch Regex::match(const string &str) const
341 {
342         RegMatch::GroupArray groups(n_groups);
343
344         for(string::const_iterator i=str.begin(); i!=str.end(); ++i)
345                 if(run(str, i, groups))
346                         return RegMatch(str, groups);
347
348         return RegMatch();
349 }
350
351 bool Regex::run(const string &str, const string::const_iterator &begin, RegMatch::GroupArray &groups) const
352 {
353         bool result = false;
354         list<RunContext> ctx;
355         ctx.push_back(RunContext());
356         ctx.front().citer = code.begin();
357         ctx.front().groups.resize(groups.size());
358
359         for(string::const_iterator i=begin;;)
360         {
361                 int c;
362                 if(i!=str.end())
363                         c = static_cast<unsigned char>(*i);
364                 else
365                         c = -1;
366
367                 for(list<RunContext>::iterator j=ctx.begin(); j!=ctx.end();)
368                 {
369                         bool terminate = false;
370                         bool negate_match = false;
371                         for(; j->citer!=code.end();)
372                         {
373                                 Instruction instr = static_cast<Instruction>(*j->citer++);
374
375                                 if(instr==NEGATE)
376                                         negate_match = true;
377                                 else if(instr==JUMP)
378                                 {
379                                         Offset offset = read_int<Offset>(j->citer);
380                                         j->citer += offset;
381                                 }
382                                 else if(instr==ND_JUMP)
383                                 {
384                                         Offset offset = read_int<Offset>(j->citer);
385                                         ctx.push_back(*j);
386                                         ctx.back().citer += offset;
387                                 }
388                                 else if(instr==GROUP_BEGIN)
389                                 {
390                                         Index n = read_int<Index>(j->citer);
391                                         if(!j->groups[n].match)
392                                                 j->groups[n].begin = i-str.begin();
393                                 }
394                                 else if(instr==GROUP_END)
395                                 {
396                                         Index n = read_int<Index>(j->citer);
397                                         if(!j->groups[n].match)
398                                         {
399                                                 j->groups[n].match = true;
400                                                 j->groups[n].end = i-str.begin();
401                                                 j->groups[n].length = j->groups[n].end-j->groups[n].begin;
402                                         }
403
404                                         if(n==0)
405                                         {
406                                                 result = true;
407                                                 bool better = false;
408                                                 for(unsigned k=0; (k<groups.size() && !better); ++k)
409                                                 {
410                                                         better = group_compare(j->groups[k], groups[k]);
411                                                         if(group_compare(groups[k], j->groups[k]))
412                                                                 break;
413                                                 }
414                                                 if(better)
415                                                         groups = j->groups;
416                                         }
417                                 }
418                                 else
419                                 {
420                                         bool match_result = false;
421                                         bool input_consumed = false;
422                                         if(instr==MATCH_BEGIN)
423                                                 match_result = (i==str.begin());
424                                         else if(instr==MATCH_END)
425                                                 match_result = (i==str.end());
426                                         else if(instr==MATCH_CHAR)
427                                         {
428                                                 match_result = (c==*j->citer++);
429                                                 input_consumed = true;
430                                         }
431                                         else if(instr==MATCH_RANGE)
432                                         {
433                                                 unsigned char first = *j->citer++;
434                                                 unsigned char last = *j->citer++;
435                                                 match_result = (c>=first && c<=last);
436                                                 input_consumed = true;
437                                         }
438                                         else if(instr==MATCH_MASK)
439                                         {
440                                                 if(c>=0 && c<=0xFF)
441                                                 {
442                                                         unsigned char m = *(j->citer+(c>>3));
443                                                         match_result = m&(1<<(c&7));
444                                                 }
445                                                 input_consumed = true;
446                                                 j->citer += 32;
447                                         }
448                                         else if(instr==MATCH_ANY)
449                                         {
450                                                 match_result = true;
451                                                 input_consumed = true;
452                                         }
453                                         else
454                                                 throw Exception("Invalid instruction");
455
456                                         if(match_result==negate_match)
457                                                 terminate = true;
458                                         negate_match = false;
459
460                                         if(input_consumed || terminate)
461                                                 break;
462                                 }
463                         }
464
465                         if(terminate || j->citer==code.end())
466                                 j = ctx.erase(j);
467                         else
468                                 ++j;
469                 }
470
471                 if(i==str.end() || ctx.empty())
472                         break;
473                 ++i;
474         }
475
476         return result;
477 }
478
479 bool Regex::group_compare(const RegMatch::Group &g1, const RegMatch::Group &g2) const
480 {
481         if(!g1.match)
482                 return false;
483
484         // Any match is better than no match
485         if(!g2.match)
486                 return true;
487
488         // Earlier match is better
489         if(g1.begin<g2.begin)
490                 return true;
491         if(g2.begin>g2.begin)
492                 return false;
493
494         // Longer match at same position is better
495         return g1.end>g2.end;
496 }
497
498 string Regex::disassemble() const
499 {
500         ostringstream ss;
501
502         for(Code::const_iterator i=code.begin(); i!=code.end();)
503         {
504                 Code::const_iterator j = i;
505                 Offset offset = i-code.begin();
506                 string decompiled = disassemble_instruction(i);
507                 string bytes;
508                 for(; j!=i; ++j)
509                         bytes += format(" %02X", static_cast<int>(*j)&0xFF);
510                 ss<<Fmt("%3d")<<offset<<':'<<Fmt("%-9s")<<bytes;
511                 if(bytes.size()>9)
512                         ss<<"\n"<<Fmt("%15s");
513                 ss<<"  "<<decompiled<<'\n';
514         }
515
516         return ss.str();
517 }
518
519 string Regex::disassemble_instruction(Code::const_iterator &i) const
520 {
521         Instruction instr = static_cast<Instruction>(*i++);
522
523         ostringstream result;
524         switch(instr)
525         {
526         case JUMP:
527                 {
528                         Offset offset = read_int<Offset>(i);
529                         result<<"JUMP "<<Fmt("%+d")<<offset<<" ("<<Fmt("%d")<<i-code.begin()+offset<<')';
530                 }
531                 break;
532         case ND_JUMP:
533                 {
534                         Offset offset = read_int<Offset>(i);
535                         result<<"ND_JUMP "<<Fmt("%+d")<<offset<<" ("<<Fmt("%d")<<i-code.begin()+offset<<')';
536                 }
537                 break;
538         case GROUP_BEGIN:
539                 result<<"GROUP_BEGIN "<<read_int<Index>(i);
540                 break;
541         case GROUP_END:
542                 result<<"GROUP_END "<<read_int<Index>(i);
543                 break;
544         case NEGATE:
545                 result<<"NEGATE";
546                 break;
547         case MATCH_BEGIN:
548                 result<<"MATCH_BEGIN";
549                 break;
550         case MATCH_END:
551                 result<<"MATCH_END";
552                 break;
553         case MATCH_CHAR:
554                 {
555                         char c = *i++;
556                         result<<"MATCH_CHAR ";
557                         if(c>=0x20 && c<=0x7E)
558                                 result<<'\''<<c<<'\'';
559                         else
560                                 result<<(static_cast<int>(c)&0xFF);
561                 }
562                 break;
563         case MATCH_RANGE:
564                 result<<"MATCH_RANGE "<<(static_cast<int>(*i++)&0xFF);
565                 result<<'-'<<(static_cast<int>(*i++)&0xFF);
566                 break;
567         case MATCH_MASK:
568                 result<<"MATCH_MASK";
569                 for(unsigned j=0; j<32; ++j)
570                         result<<' '<<Fmt("%02X")<<(static_cast<int>(*i++)&0xFF);
571                 break;
572         case MATCH_ANY:
573                 result<<"MATCH_ANY";
574                 break;
575         default:
576                 result<<"UNKNOWN "<<instr;
577         }
578
579         return result.str();
580 }
581
582 } // namespace Msp