]> git.tdb.fi Git - libs/core.git/blob - source/strings/regex.cpp
Fix signedness errors from MSVC
[libs/core.git] / source / strings / regex.cpp
1 #include <limits>
2 #include <list>
3 #include <stack>
4 #include <vector>
5 #include "format.h"
6 #include "regex.h"
7
8 using namespace std;
9
10 namespace {
11
/** Writes an integer to a Regex code string, in little-endian order.  Exactly
sizeof(T) bytes are appended, least significant byte first.  Negative values
are written in their two's complement representation. */
template<typename T>
void write_int(T n, std::basic_string<unsigned char> &code)
{
        static_assert(sizeof(T)<=sizeof(unsigned long long), "integer type too wide for write_int");

        // Work on an unsigned copy so that the shifts are well defined for
        // negative values and the narrowing to a single byte is explicit.
        const unsigned long long bits = static_cast<unsigned long long>(n);
        for(unsigned i=0; i<sizeof(T); ++i)
                code += static_cast<unsigned char>((bits>>(i*8))&0xFF);
}
19
/** Reads an integer from a Regex code string, in little-endian order.  Exactly
sizeof(T) bytes are consumed and the iterator is left pointing at the byte
following them.  The caller must make sure that many bytes are available. */
template<typename T>
T read_int(std::basic_string<unsigned char>::const_iterator &c)
{
        static_assert(sizeof(T)<=sizeof(unsigned long long), "integer type too wide for read_int");

        // Assemble the value in a wide unsigned type.  Shifting the byte
        // directly would happen in int, which overflows for the topmost byte
        // of a 32-bit value and is undefined for anything wider.
        unsigned long long bits = 0;
        for(unsigned i=0; i<sizeof(T); ++i)
                bits |= static_cast<unsigned long long>(*c++)<<(i*8);
        return static_cast<T>(bits);
}
29
30 }
31
32
33 namespace Msp {
34
35 bad_regex::bad_regex(const string &w, const string &e, const string::const_iterator &i):
36         logic_error(w+"\n"+make_where(e, i))
37 { }
38
39 string bad_regex::make_where(const string &e, const string::const_iterator &i)
40 {
41         string result;
42         string::size_type offset = i-e.begin();
43         if(offset>40)
44         {
45                 result = e.substr(offset-40, 60);
46                 offset = 40;
47         }
48         else
49                 result = e.substr(0, 60);
50         result += '\n';
51         result.append(offset, ' ');
52         result += '^';
53         return result;
54 }
55
56
57 Regex::Regex(const string &expr)
58 {
59         n_groups = 0;
60         auto iter = expr.begin();
61         code = compile(expr, iter, n_groups, false);
62         ++n_groups;
63 }
64
65 Regex::Code Regex::compile(const string &expr, string::const_iterator &iter, unsigned &group, bool branch)
66 {
67         bool has_branches = false;
68         stack<string::const_iterator> parens;
69         bool escape = false;
70         unsigned bracket = 0;
71         string::const_iterator end;
72         for(end=iter; end!=expr.end(); ++end)
73         {
74                 if(escape)
75                         escape = false;
76                 else if(bracket)
77                 {
78                         if(bracket==3 && *end==']')
79                                 bracket = 0;
80                         else if(bracket==1 && *end=='^')
81                                 bracket = 2;
82                         else
83                                 bracket = 3;
84                 }
85                 else if(*end=='\\')
86                         escape = true;
87                 else if(*end=='(')
88                         parens.push(end);
89                 else if(*end==')')
90                 {
91                         if(parens.empty())
92                         {
93                                 if(group==0)
94                                         throw bad_regex("unmatched ')'", expr, end);
95                                 else
96                                         break;
97                         }
98                         parens.pop();
99                 }
100                 else if(*end=='|' && parens.empty())
101                 {
102                         if(branch)
103                                 break;
104                         else
105                                 has_branches = true;
106                 }
107                 else if(*end=='[')
108                         bracket = 1;
109         }
110
111         if(!parens.empty())
112                 throw bad_regex("unmatched '('", expr, parens.top());
113
114         Code result;
115
116         unsigned this_group = group;
117         if(!branch)
118         {
119                 result += GROUP_BEGIN;
120                 write_int<Index>(this_group, result);
121         }
122
123         const unsigned jump_size = 1+sizeof(Offset);
124
125         if(!has_branches)
126         {
127                 for(auto i=iter; i!=end;)
128                 {
129                         Code atom = parse_atom(expr, i, group);
130
131                         Count repeat_min = 1;
132                         Count repeat_max = 1;
133                         parse_repeat(expr, i, repeat_min, repeat_max);
134
135                         for(unsigned j=0; j<repeat_min; ++j)
136                                 result += atom;
137                         if(repeat_max==numeric_limits<Count>::max())
138                         {
139                                 if(repeat_min==0)
140                                 {
141                                         result += ND_JUMP;
142                                         write_int<Offset>(atom.size()+jump_size, result);
143                                         result += atom;
144                                 }
145                                 result += ND_JUMP;
146                                 write_int<Offset>(-static_cast<Offset>(atom.size()+jump_size), result);
147                         }
148                         else if(repeat_max>repeat_min)
149                         {
150                                 for(unsigned j=repeat_min; j<repeat_max; ++j)
151                                 {
152                                         result += ND_JUMP;
153                                         write_int<Offset>((repeat_max-j)*(atom.size()+jump_size)-jump_size, result);
154                                         result += atom;
155                                 }
156                         }
157                 }
158         }
159         else
160         {
161                 vector<Code> branches;
162                 for(auto i=iter;;)
163                 {
164                         branches.push_back(compile(expr, i, group, true));
165                         if(i==end)
166                                 break;
167                         ++i;
168                 }
169
170                 unsigned n_branches = branches.size();
171
172                 Offset offset = (n_branches-1)*jump_size+branches.front().size();
173                 for(auto i=++branches.begin(); i!=branches.end(); ++i)
174                 {
175                         result += ND_JUMP;
176                         write_int<Offset>(offset, result);
177                         offset += i->size();
178                 }
179
180                 for(auto i=branches.begin(); i!=branches.end();)
181                 {
182                         result += *i;
183                         offset -= i->size()+jump_size;
184                         ++i;
185                         if(i!=branches.end())
186                         {
187                                 result += JUMP;
188                                 write_int<Offset>(offset, result);
189                         }
190                 }
191         }
192
193         if(!branch)
194         {
195                 result += GROUP_END;
196                 write_int<Index>(this_group, result);
197         }
198
199         iter = end;
200
201         return result;
202 }
203
204 Regex::Code Regex::parse_atom(const string &expr, string::const_iterator &i, unsigned &group)
205 {
206         Code result;
207
208         if(i==expr.end())
209                 return result;
210
211         bool flag = false;
212         if(*i=='\\')
213         {
214                 if(++i==expr.end())
215                         throw bad_regex("stray backslash", expr, i-1);
216                 flag = true;
217         }
218
219         if(!flag)
220         {
221                 if(*i=='*' || *i=='{' || *i=='}' || *i=='+' || *i=='?' || *i=='|' || *i==')')
222                         throw bad_regex("invalid atom", expr, i);
223                 else if(*i=='[')
224                         return parse_brackets(expr, i);
225                 else if(*i=='.')
226                         result += MATCH_ANY;
227                 else if(*i=='^')
228                         result += MATCH_BEGIN;
229                 else if(*i=='$')
230                         result += MATCH_END;
231                 else if(*i=='(')
232                 {
233                         ++group;
234                         result = compile(expr, ++i, group, false);
235                 }
236                 else
237                         flag = true;
238         }
239
240         if(flag)
241         {
242                 result += MATCH_CHAR;
243                 result += *i;
244         }
245
246         ++i;
247
248         return result;
249 }
250
251 bool Regex::parse_repeat(const string &expr, string::const_iterator &i, Count &rmin, Count &rmax)
252 {
253         if(*i!='*' && *i!='+' && *i!='?' && *i!='{')
254                 return false;
255
256         if(*i=='*' || *i=='+')
257                 rmax = numeric_limits<Count>::max();
258         if(*i=='*' || *i=='?')
259                 rmin = 0;
260         if(*i=='{')
261         {
262                 auto begin = i;
263
264                 rmin = 0;
265                 for(++i; isdigit(*i); ++i)
266                         rmin = rmin*10+(*i-'0');
267
268                 if(*i==',')
269                 {
270                         ++i;
271                         if(*i!='}')
272                         {
273                                 rmax = 0;
274                                 for(; isdigit(*i); ++i)
275                                         rmax = rmax*10+(*i-'0');
276                                 if(rmax<rmin)
277                                         throw bad_regex("invalid bound", expr, begin);
278                         }
279                         else
280                                 rmax = numeric_limits<Count>::max();
281                 }
282                 else
283                         rmax = rmin;
284                 if(*i!='}')
285                         throw bad_regex("invalid bound", expr, begin);
286         }
287
288         ++i;
289
290         return true;
291 }
292
293 Regex::Code Regex::parse_brackets(const string &str, string::const_iterator &iter)
294 {
295         auto begin = iter;
296         Code result;
297
298         ++iter;
299         bool neg = false;
300         if(*iter=='^')
301         {
302                 neg = true;
303                 ++iter;
304         }
305
306         auto end = iter;
307         for(; (end!=str.end() && (end==iter || *end!=']')); ++end) ;
308         if(end==str.end())
309                 throw bad_regex("unmatched '['", str, begin);
310
311         unsigned char mask[32] = {0};
312         unsigned type = 0;
313         bool range = false;
314         unsigned char first = 0, last = 0;
315         for(auto i=iter; i!=end; ++i)
316         {
317                 unsigned char c = *i;
318                 if(range)
319                 {
320                         last = c;
321                         for(unsigned j=first; j<=c; ++j)
322                                 mask[j>>3] |= 1<<(j&7);
323                         range = false;
324                         if(type<2)
325                                 type = 2;
326                 }
327                 else if(c=='-' && i!=iter && end-i>1)
328                         range = true;
329                 else
330                 {
331                         first = c;
332                         mask[c>>3] |= 1<<(c&7);
333                         if(type==0)
334                                 type = 1;
335                         else
336                                 type = 3;
337                 }
338         }
339
340         if(neg)
341                 result += NEGATE;
342
343         if(type==1)
344         {
345                 result += MATCH_CHAR;
346                 result += first;
347         }
348         else if(type==2)
349         {
350                 result += MATCH_RANGE;
351                 result += first;
352                 result += last;
353         }
354         else
355         {
356                 result += MATCH_MASK;
357                 result.append(mask, 32);
358         }
359
360         iter = end;
361         ++iter;
362
363         return result;
364 }
365
366 RegMatch Regex::match(const string &str) const
367 {
368         vector<RegMatch::Group> groups(n_groups);
369
370         for(auto i=str.begin(); i!=str.end(); ++i)
371                 if(run(str, i, groups))
372                         return RegMatch(str, groups);
373
374         return RegMatch();
375 }
376
377 bool Regex::run(const string &str, const string::const_iterator &begin, vector<RegMatch::Group> &groups) const
378 {
379         bool result = false;
380         list<RunContext> ctx;
381         ctx.push_back(RunContext());
382         ctx.front().citer = code.begin();
383         ctx.front().groups.resize(groups.size());
384
385         for(auto i=begin;;)
386         {
387                 int c;
388                 if(i!=str.end())
389                         c = *i;
390                 else
391                         c = -1;
392
393                 for(auto j=ctx.begin(); j!=ctx.end();)
394                 {
395                         bool terminate = false;
396                         bool negate_match = false;
397                         for(; j->citer!=code.end();)
398                         {
399                                 Instruction instr = static_cast<Instruction>(*j->citer++);
400
401                                 if(instr==NEGATE)
402                                         negate_match = true;
403                                 else if(instr==JUMP)
404                                 {
405                                         Offset offset = read_int<Offset>(j->citer);
406                                         j->citer += offset;
407                                 }
408                                 else if(instr==ND_JUMP)
409                                 {
410                                         Offset offset = read_int<Offset>(j->citer);
411                                         ctx.push_back(*j);
412                                         ctx.back().citer += offset;
413                                 }
414                                 else if(instr==GROUP_BEGIN)
415                                 {
416                                         Index n = read_int<Index>(j->citer);
417                                         if(!j->groups[n].match)
418                                                 j->groups[n].begin = i-str.begin();
419                                 }
420                                 else if(instr==GROUP_END)
421                                 {
422                                         Index n = read_int<Index>(j->citer);
423                                         if(!j->groups[n].match)
424                                         {
425                                                 j->groups[n].match = true;
426                                                 j->groups[n].end = i-str.begin();
427                                                 j->groups[n].length = j->groups[n].end-j->groups[n].begin;
428                                         }
429
430                                         if(n==0)
431                                         {
432                                                 result = true;
433                                                 bool better = false;
434                                                 for(unsigned k=0; (k<groups.size() && !better); ++k)
435                                                 {
436                                                         better = group_compare(j->groups[k], groups[k]);
437                                                         if(group_compare(groups[k], j->groups[k]))
438                                                                 break;
439                                                 }
440                                                 if(better)
441                                                         groups = j->groups;
442                                         }
443                                 }
444                                 else
445                                 {
446                                         bool match_result = false;
447                                         bool input_consumed = false;
448                                         if(instr==MATCH_BEGIN)
449                                                 match_result = (i==str.begin());
450                                         else if(instr==MATCH_END)
451                                                 match_result = (i==str.end());
452                                         else if(instr==MATCH_CHAR)
453                                         {
454                                                 match_result = (c==*j->citer++);
455                                                 input_consumed = true;
456                                         }
457                                         else if(instr==MATCH_RANGE)
458                                         {
459                                                 unsigned char first = *j->citer++;
460                                                 unsigned char last = *j->citer++;
461                                                 match_result = (c>=first && c<=last);
462                                                 input_consumed = true;
463                                         }
464                                         else if(instr==MATCH_MASK)
465                                         {
466                                                 if(c>=0 && c<=0xFF)
467                                                 {
468                                                         unsigned char m = *(j->citer+(c>>3));
469                                                         match_result = m&(1<<(c&7));
470                                                 }
471                                                 input_consumed = true;
472                                                 j->citer += 32;
473                                         }
474                                         else if(instr==MATCH_ANY)
475                                         {
476                                                 match_result = true;
477                                                 input_consumed = true;
478                                         }
479                                         else
480                                                 throw logic_error("invalid instruction in regex bytecode");
481
482                                         if(match_result==negate_match)
483                                                 terminate = true;
484                                         negate_match = false;
485
486                                         if(input_consumed || terminate)
487                                                 break;
488                                 }
489                         }
490
491                         if(terminate || j->citer==code.end())
492                                 j = ctx.erase(j);
493                         else
494                                 ++j;
495                 }
496
497                 if(i==str.end() || ctx.empty())
498                         break;
499                 ++i;
500         }
501
502         return result;
503 }
504
505 bool Regex::group_compare(const RegMatch::Group &g1, const RegMatch::Group &g2) const
506 {
507         if(!g1.match)
508                 return false;
509
510         // Any match is better than no match
511         if(!g2.match)
512                 return true;
513
514         // Earlier match is better
515         if(g1.begin<g2.begin)
516                 return true;
517         if(g1.begin>g2.begin)
518                 return false;
519
520         // Longer match at same position is better
521         return g1.end>g2.end;
522 }
523
524 string Regex::disassemble() const
525 {
526         string result;
527
528         for(auto i=code.begin(); i!=code.end();)
529         {
530                 auto j = i;
531                 Offset offset = i-code.begin();
532                 string decompiled = disassemble_instruction(i);
533                 string bytes;
534                 for(; j!=i; ++j)
535                         bytes += format(" %02X", *j);
536                 result += format("%3d:%-9s  ", offset, bytes);
537                 if(bytes.size()>9)
538                         result += "\n               ";
539                 result += decompiled;
540                 result += '\n';
541         }
542
543         return result;
544 }
545
546 string Regex::disassemble_instruction(Code::const_iterator &i) const
547 {
548         Instruction instr = static_cast<Instruction>(*i++);
549
550         switch(instr)
551         {
552         case JUMP:
553                 {
554                         Offset offset = read_int<Offset>(i);
555                         return format("JUMP %+d (%d)", offset, (i-code.begin())+offset);
556                 }
557         case ND_JUMP:
558                 {
559                         Offset offset = read_int<Offset>(i);
560                         return format("ND_JUMP %+d (%d)", offset, (i-code.begin())+offset);
561                 }
562         case GROUP_BEGIN:
563                 return format("GROUP_BEGIN %d", read_int<Index>(i));
564         case GROUP_END:
565                 return format("GROUP_END %d", read_int<Index>(i));
566         case NEGATE:
567                 return "NEGATE";
568         case MATCH_BEGIN:
569                 return "MATCH_BEGIN";
570         case MATCH_END:
571                 return "MATCH_END";
572         case MATCH_CHAR:
573                 {
574                         unsigned char c = *i++;
575                         if(c>=0x20 && c<=0x7E)
576                                 return format("MATCH_CHAR '%c'", c);
577                         else
578                                 return format("MATCH_CHAR %d", c);
579                 }
580                 break;
581         case MATCH_RANGE:
582                 {
583                         int begin = *i++;
584                         int end = *i++;
585                         return format("MATCH_RANGE %d-%d", begin, end);
586                 }
587         case MATCH_MASK:
588                 {
589                         string result = "MATCH_MASK";
590                         for(unsigned j=0; j<32; ++j)
591                                 result += format(" %02X", *i++);
592                         return result;
593                 }
594         case MATCH_ANY:
595                 return "MATCH_ANY";
596         default:
597                 return format("UNKNOWN %d", instr);
598         }
599 }
600
601 } // namespace Msp