]> git.tdb.fi Git - libs/core.git/blob - source/strings/regex.cpp
Don't access iterator if it's at end
[libs/core.git] / source / strings / regex.cpp
1 #include <limits>
2 #include <list>
3 #include <stack>
4 #include <vector>
5 #include "format.h"
6 #include "regex.h"
7
8 using namespace std;
9
10 namespace {
11
/** Appends the sizeof(T) bytes of an integer to a Regex code string, least
significant byte first. */
template<typename T>
void write_int(T n, basic_string<unsigned char> &code)
{
	unsigned shift = 0;
	for(unsigned remaining=sizeof(T); remaining>0; --remaining, shift+=8)
		code.push_back(static_cast<unsigned char>((n>>shift)&0xFF));
}
19
/** Reads an integer from a Regex code string, in little-endian order.  The
iterator is advanced past the sizeof(T) bytes that were read; the caller must
make sure that many bytes are available. */
template<typename T>
T read_int(basic_string<unsigned char>::const_iterator &c)
{
	/* Assemble the value in a wide unsigned type.  Shifting the byte as a
	promoted int would overflow on the top byte of a 32-bit value and is
	undefined for shift counts of 32 or more. */
	unsigned long long bits = 0;
	for(unsigned i=0; i<sizeof(T); ++i)
		bits |= static_cast<unsigned long long>(*c++)<<(i*8);
	return static_cast<T>(bits);
}
29
30 }
31
32
33 namespace Msp {
34
35 bad_regex::bad_regex(const string &w, const string &e, const string::const_iterator &i):
36         logic_error(w+"\n"+make_where(e, i))
37 { }
38
39 string bad_regex::make_where(const string &e, const string::const_iterator &i)
40 {
41         string result;
42         string::size_type offset = i-e.begin();
43         if(offset>40)
44         {
45                 result = e.substr(offset-40, 60);
46                 offset = 40;
47         }
48         else
49                 result = e.substr(0, 60);
50         result += '\n';
51         result.append(offset, ' ');
52         result += '^';
53         return result;
54 }
55
56
57 Regex::Regex(const string &expr)
58 {
59         n_groups = 0;
60         auto iter = expr.begin();
61         code = compile(expr, iter, n_groups, false);
62         ++n_groups;
63 }
64
65 Regex::Code Regex::compile(const string &expr, string::const_iterator &iter, unsigned &group, bool branch)
66 {
67         bool has_branches = false;
68         stack<string::const_iterator> parens;
69         bool escape = false;
70         unsigned bracket = 0;
71         string::const_iterator end;
72         for(end=iter; end!=expr.end(); ++end)
73         {
74                 if(escape)
75                         escape = false;
76                 else if(bracket)
77                 {
78                         if(bracket==3 && *end==']')
79                                 bracket = 0;
80                         else if(bracket==1 && *end=='^')
81                                 bracket = 2;
82                         else
83                                 bracket = 3;
84                 }
85                 else if(*end=='\\')
86                         escape = true;
87                 else if(*end=='(')
88                         parens.push(end);
89                 else if(*end==')')
90                 {
91                         if(parens.empty())
92                         {
93                                 if(group==0)
94                                         throw bad_regex("unmatched ')'", expr, end);
95                                 else
96                                         break;
97                         }
98                         parens.pop();
99                 }
100                 else if(*end=='|' && parens.empty())
101                 {
102                         if(branch)
103                                 break;
104                         else
105                                 has_branches = true;
106                 }
107                 else if(*end=='[')
108                         bracket = 1;
109         }
110
111         if(!parens.empty())
112                 throw bad_regex("unmatched '('", expr, parens.top());
113
114         Code result;
115
116         unsigned this_group = group;
117         if(!branch)
118         {
119                 result += GROUP_BEGIN;
120                 write_int<Index>(this_group, result);
121         }
122
123         const unsigned jump_size = 1+sizeof(Offset);
124
125         if(!has_branches)
126         {
127                 for(auto i=iter; i!=end;)
128                 {
129                         Code atom = parse_atom(expr, i, group);
130
131                         Count repeat_min = 1;
132                         Count repeat_max = 1;
133                         if(i!=end)
134                                 parse_repeat(expr, i, repeat_min, repeat_max);
135
136                         for(unsigned j=0; j<repeat_min; ++j)
137                                 result += atom;
138                         if(repeat_max==numeric_limits<Count>::max())
139                         {
140                                 if(repeat_min==0)
141                                 {
142                                         result += ND_JUMP;
143                                         write_int<Offset>(atom.size()+jump_size, result);
144                                         result += atom;
145                                 }
146                                 result += ND_JUMP;
147                                 write_int<Offset>(-static_cast<Offset>(atom.size()+jump_size), result);
148                         }
149                         else if(repeat_max>repeat_min)
150                         {
151                                 for(unsigned j=repeat_min; j<repeat_max; ++j)
152                                 {
153                                         result += ND_JUMP;
154                                         write_int<Offset>((repeat_max-j)*(atom.size()+jump_size)-jump_size, result);
155                                         result += atom;
156                                 }
157                         }
158                 }
159         }
160         else
161         {
162                 vector<Code> branches;
163                 for(auto i=iter;;)
164                 {
165                         branches.push_back(compile(expr, i, group, true));
166                         if(i==end)
167                                 break;
168                         ++i;
169                 }
170
171                 unsigned n_branches = branches.size();
172
173                 Offset offset = (n_branches-1)*jump_size+branches.front().size();
174                 for(auto i=++branches.begin(); i!=branches.end(); ++i)
175                 {
176                         result += ND_JUMP;
177                         write_int<Offset>(offset, result);
178                         offset += i->size();
179                 }
180
181                 for(auto i=branches.begin(); i!=branches.end();)
182                 {
183                         result += *i;
184                         offset -= i->size()+jump_size;
185                         ++i;
186                         if(i!=branches.end())
187                         {
188                                 result += JUMP;
189                                 write_int<Offset>(offset, result);
190                         }
191                 }
192         }
193
194         if(!branch)
195         {
196                 result += GROUP_END;
197                 write_int<Index>(this_group, result);
198         }
199
200         iter = end;
201
202         return result;
203 }
204
205 Regex::Code Regex::parse_atom(const string &expr, string::const_iterator &i, unsigned &group)
206 {
207         Code result;
208
209         if(i==expr.end())
210                 return result;
211
212         bool flag = false;
213         if(*i=='\\')
214         {
215                 if(++i==expr.end())
216                         throw bad_regex("stray backslash", expr, i-1);
217                 flag = true;
218         }
219
220         if(!flag)
221         {
222                 if(*i=='*' || *i=='{' || *i=='}' || *i=='+' || *i=='?' || *i=='|' || *i==')')
223                         throw bad_regex("invalid atom", expr, i);
224                 else if(*i=='[')
225                         return parse_brackets(expr, i);
226                 else if(*i=='.')
227                         result += MATCH_ANY;
228                 else if(*i=='^')
229                         result += MATCH_BEGIN;
230                 else if(*i=='$')
231                         result += MATCH_END;
232                 else if(*i=='(')
233                 {
234                         ++group;
235                         result = compile(expr, ++i, group, false);
236                 }
237                 else
238                         flag = true;
239         }
240
241         if(flag)
242         {
243                 result += MATCH_CHAR;
244                 result += *i;
245         }
246
247         ++i;
248
249         return result;
250 }
251
252 bool Regex::parse_repeat(const string &expr, string::const_iterator &i, Count &rmin, Count &rmax)
253 {
254         if(*i!='*' && *i!='+' && *i!='?' && *i!='{')
255                 return false;
256
257         if(*i=='*' || *i=='+')
258                 rmax = numeric_limits<Count>::max();
259         if(*i=='*' || *i=='?')
260                 rmin = 0;
261         if(*i=='{')
262         {
263                 auto begin = i;
264
265                 rmin = 0;
266                 for(++i; isdigit(*i); ++i)
267                         rmin = rmin*10+(*i-'0');
268
269                 if(*i==',')
270                 {
271                         ++i;
272                         if(*i!='}')
273                         {
274                                 rmax = 0;
275                                 for(; isdigit(*i); ++i)
276                                         rmax = rmax*10+(*i-'0');
277                                 if(rmax<rmin)
278                                         throw bad_regex("invalid bound", expr, begin);
279                         }
280                         else
281                                 rmax = numeric_limits<Count>::max();
282                 }
283                 else
284                         rmax = rmin;
285                 if(*i!='}')
286                         throw bad_regex("invalid bound", expr, begin);
287         }
288
289         ++i;
290
291         return true;
292 }
293
294 Regex::Code Regex::parse_brackets(const string &str, string::const_iterator &iter)
295 {
296         auto begin = iter;
297         Code result;
298
299         ++iter;
300         bool neg = false;
301         if(*iter=='^')
302         {
303                 neg = true;
304                 ++iter;
305         }
306
307         auto end = iter;
308         for(; (end!=str.end() && (end==iter || *end!=']')); ++end) ;
309         if(end==str.end())
310                 throw bad_regex("unmatched '['", str, begin);
311
312         unsigned char mask[32] = {0};
313         unsigned type = 0;
314         bool range = false;
315         unsigned char first = 0, last = 0;
316         for(auto i=iter; i!=end; ++i)
317         {
318                 unsigned char c = *i;
319                 if(range)
320                 {
321                         last = c;
322                         for(unsigned j=first; j<=c; ++j)
323                                 mask[j>>3] |= 1<<(j&7);
324                         range = false;
325                         if(type<2)
326                                 type = 2;
327                 }
328                 else if(c=='-' && i!=iter && end-i>1)
329                         range = true;
330                 else
331                 {
332                         first = c;
333                         mask[c>>3] |= 1<<(c&7);
334                         if(type==0)
335                                 type = 1;
336                         else
337                                 type = 3;
338                 }
339         }
340
341         if(neg)
342                 result += NEGATE;
343
344         if(type==1)
345         {
346                 result += MATCH_CHAR;
347                 result += first;
348         }
349         else if(type==2)
350         {
351                 result += MATCH_RANGE;
352                 result += first;
353                 result += last;
354         }
355         else
356         {
357                 result += MATCH_MASK;
358                 result.append(mask, 32);
359         }
360
361         iter = end;
362         ++iter;
363
364         return result;
365 }
366
367 RegMatch Regex::match(const string &str) const
368 {
369         vector<RegMatch::Group> groups(n_groups);
370
371         for(auto i=str.begin(); i!=str.end(); ++i)
372                 if(run(str, i, groups))
373                         return RegMatch(str, groups);
374
375         return RegMatch();
376 }
377
378 bool Regex::run(const string &str, const string::const_iterator &begin, vector<RegMatch::Group> &groups) const
379 {
380         bool result = false;
381         list<RunContext> ctx;
382         ctx.push_back(RunContext());
383         ctx.front().citer = code.begin();
384         ctx.front().groups.resize(groups.size());
385
386         for(auto i=begin;;)
387         {
388                 int c;
389                 if(i!=str.end())
390                         c = *i;
391                 else
392                         c = -1;
393
394                 for(auto j=ctx.begin(); j!=ctx.end();)
395                 {
396                         bool terminate = false;
397                         bool negate_match = false;
398                         for(; j->citer!=code.end();)
399                         {
400                                 Instruction instr = static_cast<Instruction>(*j->citer++);
401
402                                 if(instr==NEGATE)
403                                         negate_match = true;
404                                 else if(instr==JUMP)
405                                 {
406                                         Offset offset = read_int<Offset>(j->citer);
407                                         j->citer += offset;
408                                 }
409                                 else if(instr==ND_JUMP)
410                                 {
411                                         Offset offset = read_int<Offset>(j->citer);
412                                         ctx.push_back(*j);
413                                         ctx.back().citer += offset;
414                                 }
415                                 else if(instr==GROUP_BEGIN)
416                                 {
417                                         Index n = read_int<Index>(j->citer);
418                                         if(!j->groups[n].match)
419                                                 j->groups[n].begin = i-str.begin();
420                                 }
421                                 else if(instr==GROUP_END)
422                                 {
423                                         Index n = read_int<Index>(j->citer);
424                                         if(!j->groups[n].match)
425                                         {
426                                                 j->groups[n].match = true;
427                                                 j->groups[n].end = i-str.begin();
428                                                 j->groups[n].length = j->groups[n].end-j->groups[n].begin;
429                                         }
430
431                                         if(n==0)
432                                         {
433                                                 result = true;
434                                                 bool better = false;
435                                                 for(unsigned k=0; (k<groups.size() && !better); ++k)
436                                                 {
437                                                         better = group_compare(j->groups[k], groups[k]);
438                                                         if(group_compare(groups[k], j->groups[k]))
439                                                                 break;
440                                                 }
441                                                 if(better)
442                                                         groups = j->groups;
443                                         }
444                                 }
445                                 else
446                                 {
447                                         bool match_result = false;
448                                         bool input_consumed = false;
449                                         if(instr==MATCH_BEGIN)
450                                                 match_result = (i==str.begin());
451                                         else if(instr==MATCH_END)
452                                                 match_result = (i==str.end());
453                                         else if(instr==MATCH_CHAR)
454                                         {
455                                                 match_result = (c==*j->citer++);
456                                                 input_consumed = true;
457                                         }
458                                         else if(instr==MATCH_RANGE)
459                                         {
460                                                 unsigned char first = *j->citer++;
461                                                 unsigned char last = *j->citer++;
462                                                 match_result = (c>=first && c<=last);
463                                                 input_consumed = true;
464                                         }
465                                         else if(instr==MATCH_MASK)
466                                         {
467                                                 if(c>=0 && c<=0xFF)
468                                                 {
469                                                         unsigned char m = *(j->citer+(c>>3));
470                                                         match_result = m&(1<<(c&7));
471                                                 }
472                                                 input_consumed = true;
473                                                 j->citer += 32;
474                                         }
475                                         else if(instr==MATCH_ANY)
476                                         {
477                                                 match_result = true;
478                                                 input_consumed = true;
479                                         }
480                                         else
481                                                 throw logic_error("invalid instruction in regex bytecode");
482
483                                         if(match_result==negate_match)
484                                                 terminate = true;
485                                         negate_match = false;
486
487                                         if(input_consumed || terminate)
488                                                 break;
489                                 }
490                         }
491
492                         if(terminate || j->citer==code.end())
493                                 j = ctx.erase(j);
494                         else
495                                 ++j;
496                 }
497
498                 if(i==str.end() || ctx.empty())
499                         break;
500                 ++i;
501         }
502
503         return result;
504 }
505
506 bool Regex::group_compare(const RegMatch::Group &g1, const RegMatch::Group &g2) const
507 {
508         if(!g1.match)
509                 return false;
510
511         // Any match is better than no match
512         if(!g2.match)
513                 return true;
514
515         // Earlier match is better
516         if(g1.begin<g2.begin)
517                 return true;
518         if(g1.begin>g2.begin)
519                 return false;
520
521         // Longer match at same position is better
522         return g1.end>g2.end;
523 }
524
525 string Regex::disassemble() const
526 {
527         string result;
528
529         for(auto i=code.begin(); i!=code.end();)
530         {
531                 auto j = i;
532                 Offset offset = i-code.begin();
533                 string decompiled = disassemble_instruction(i);
534                 string bytes;
535                 for(; j!=i; ++j)
536                         bytes += format(" %02X", *j);
537                 result += format("%3d:%-9s  ", offset, bytes);
538                 if(bytes.size()>9)
539                         result += "\n               ";
540                 result += decompiled;
541                 result += '\n';
542         }
543
544         return result;
545 }
546
547 string Regex::disassemble_instruction(Code::const_iterator &i) const
548 {
549         Instruction instr = static_cast<Instruction>(*i++);
550
551         switch(instr)
552         {
553         case JUMP:
554                 {
555                         Offset offset = read_int<Offset>(i);
556                         return format("JUMP %+d (%d)", offset, (i-code.begin())+offset);
557                 }
558         case ND_JUMP:
559                 {
560                         Offset offset = read_int<Offset>(i);
561                         return format("ND_JUMP %+d (%d)", offset, (i-code.begin())+offset);
562                 }
563         case GROUP_BEGIN:
564                 return format("GROUP_BEGIN %d", read_int<Index>(i));
565         case GROUP_END:
566                 return format("GROUP_END %d", read_int<Index>(i));
567         case NEGATE:
568                 return "NEGATE";
569         case MATCH_BEGIN:
570                 return "MATCH_BEGIN";
571         case MATCH_END:
572                 return "MATCH_END";
573         case MATCH_CHAR:
574                 {
575                         unsigned char c = *i++;
576                         if(c>=0x20 && c<=0x7E)
577                                 return format("MATCH_CHAR '%c'", c);
578                         else
579                                 return format("MATCH_CHAR %d", c);
580                 }
581                 break;
582         case MATCH_RANGE:
583                 {
584                         int begin = *i++;
585                         int end = *i++;
586                         return format("MATCH_RANGE %d-%d", begin, end);
587                 }
588         case MATCH_MASK:
589                 {
590                         string result = "MATCH_MASK";
591                         for(unsigned j=0; j<32; ++j)
592                                 result += format(" %02X", *i++);
593                         return result;
594                 }
595         case MATCH_ANY:
596                 return "MATCH_ANY";
597         default:
598                 return format("UNKNOWN %d", instr);
599         }
600 }
601
602 } // namespace Msp