Implement an asynchronous name resolver class
[libs/net.git] / source / net / resolve.cpp
1 #ifdef WIN32
2 #define _WIN32_WINNT 0x0501
3 #include <ws2tcpip.h>
4 #else
5 #include <netdb.h>
6 #endif
7 #include <msp/core/systemerror.h>
8 #include <msp/strings/format.h>
9 #include "sockaddr_private.h"
10 #include "socket.h"
11 #include "resolve.h"
12
13 using namespace std;
14
namespace {

/* Splits a "host:serv" string into its host and service parts.

IPv6 literals may be enclosed in brackets ("[::1]:80") so that their colons
are not mistaken for the separator; a missing closing bracket makes the rest
of the string the host.  An unbracketed string containing more than one colon
cannot be "name:service" (host names never contain colons), so it is taken to
be a bare IPv6 literal such as "::1" with no service.

host is always assigned; serv is left untouched when no service is present. */
void parse_host_serv(const std::string &str, std::string &host, std::string &serv)
{
	if(!str.empty() && str[0]=='[')
	{
		std::string::size_type bracket = str.find(']');
		// If bracket is npos, bracket-1 is huge and substr takes the remainder.
		host = str.substr(1, bracket-1);
		std::string::size_type colon = str.find(':', bracket);
		if(colon!=std::string::npos)
			serv = str.substr(colon+1);
	}
	else
	{
		std::string::size_type colon = str.find(':');
		bool single_colon = (colon!=std::string::npos && str.find(':', colon+1)==std::string::npos);
		if(single_colon)
		{
			host = str.substr(0, colon);
			serv = str.substr(colon+1);
		}
		else
		{
			// No colon at all, or an unbracketed IPv6 literal.
			host = str;
		}
	}
}

}
41
42
43 namespace Msp {
44 namespace Net {
45
46 SockAddr *resolve(const string &host, const string &serv, Family family)
47 {
48         const char *chost = (host.empty() ? 0 : host.c_str());
49         const char *cserv = (serv.empty() ? 0 : serv.c_str());
50         int flags = 0;
51         if(host=="*")
52         {
53                 flags = AI_PASSIVE;
54                 chost = 0;
55         }
56
57         addrinfo hints = { flags, family_to_sys(family), 0, 0, 0, 0, 0, 0 };
58         addrinfo *res;
59
60         int err = getaddrinfo(chost, cserv, &hints, &res);
61         if(err==0)
62         {
63                 SockAddr::SysAddr sa;
64                 sa.size = res->ai_addrlen;
65                 const char *sptr = reinterpret_cast<const char *>(res->ai_addr);
66                 char *dptr = reinterpret_cast<char *>(&sa.addr);
67                 copy(sptr, sptr+res->ai_addrlen, dptr);
68                 SockAddr *addr = SockAddr::new_from_sys(sa);
69                 freeaddrinfo(res);
70                 return addr;
71         }
72         else
73 #ifdef WIN32
74                 throw system_error("getaddrinfo", WSAGetLastError());
75 #else
76                 throw system_error("getaddrinfo", gai_strerror(err));
77 #endif
78 }
79
80 SockAddr *resolve(const string &str, Family family)
81 {
82         string host, serv;
83         parse_host_serv(str, host, serv);
84
85         return resolve(host, serv, family);
86 }
87
88
89 Resolver::Resolver():
90         event_disp(0),
91         next_tag(1)
92 {
93         thread.get_notify_pipe().signal_data_available.connect(sigc::mem_fun(this, &Resolver::task_done));
94 }
95
96 void Resolver::use_event_dispatcher(IO::EventDispatcher *ed)
97 {
98         if(event_disp)
99                 event_disp->remove(thread.get_notify_pipe());
100         event_disp = ed;
101         if(event_disp)
102                 event_disp->add(thread.get_notify_pipe());
103 }
104
105 unsigned Resolver::resolve(const string &host, const string &serv, Family family)
106 {
107         Task task;
108         task.tag = next_tag++;
109         task.host = host;
110         task.serv = serv;
111         task.family = family;
112         thread.add_task(task);
113         return task.tag;
114 }
115
116 unsigned Resolver::resolve(const string &str, Family family)
117 {
118         string host, serv;
119         parse_host_serv(str, host, serv);
120
121         return resolve(host, serv, family);
122 }
123
124 void Resolver::tick()
125 {
126         if(IO::poll(thread.get_notify_pipe(), IO::P_INPUT, Time::zero))
127                 task_done();
128 }
129
130 void Resolver::task_done()
131 {
132         char buf[64];
133         thread.get_notify_pipe().read(buf, sizeof(buf));
134
135         while(Task *task = thread.get_complete_task())
136         {
137                 if(task->addr)
138                         signal_address_resolved.emit(task->tag, *task->addr);
139                 else if(task->error)
140                         signal_resolve_failed.emit(task->tag, *task->error);
141                 thread.pop_complete_task();
142         }
143 }
144
145
146 Resolver::Task::Task():
147         tag(0),
148         family(UNSPEC),
149         addr(0),
150         error(0)
151 { }
152
153
154 Resolver::WorkerThread::WorkerThread():
155         Thread("Resolver"),
156         sem(1),
157         done(false)
158 {
159         launch();
160 }
161
162 Resolver::WorkerThread::~WorkerThread()
163 {
164         done = true;
165         sem.signal();
166         join();
167 }
168
169 void Resolver::WorkerThread::add_task(const Task &t)
170 {
171         MutexLock lock(queue_mutex);
172         bool was_starved = (queue.empty() || queue.back().is_complete());
173         queue.push_back(t);
174         if(was_starved)
175                 sem.signal();
176 }
177
178 Resolver::Task *Resolver::WorkerThread::get_complete_task()
179 {
180         MutexLock lock(queue_mutex);
181         if(!queue.empty() && queue.front().is_complete())
182                 return &queue.front();
183         else
184                 return 0;
185 }
186
187 void Resolver::WorkerThread::pop_complete_task()
188 {
189         MutexLock lock(queue_mutex);
190         if(!queue.empty() && queue.front().is_complete())
191         {
192                 delete queue.front().addr;
193                 delete queue.front().error;
194                 queue.pop_front();
195         }
196 }
197
198 void Resolver::WorkerThread::main()
199 {
200         bool wait = true;
201         while(!done)
202         {
203                 if(wait)
204                         sem.wait();
205                 wait = false;
206
207                 Task *task = 0;
208                 {
209                         MutexLock lock(queue_mutex);
210                         for(list<Task>::iterator i=queue.begin(); (!task && i!=queue.end()); ++i)
211                                 if(!i->is_complete())
212                                         task = &*i;
213                 }
214
215                 if(task)
216                 {
217                         try
218                         {
219                                 SockAddr *addr = Net::resolve(task->host, task->serv, task->family);
220                                 {
221                                         MutexLock lock(queue_mutex);
222                                         task->addr = addr;
223                                 }
224                         }
225                         catch(const runtime_error &e)
226                         {
227                                 MutexLock lock(queue_mutex);
228                                 task->error = new runtime_error(e);
229                         }
230                         notify_pipe.put(1);
231                 }
232                 else
233                         wait = true;
234         }
235 }
236
237                 /*sockaddr sa;
238                 unsigned size = fill_sockaddr(sa);
239                 char hst[128];
240                 char srv[128];
241                 int err = getnameinfo(&sa, size, hst, 128, srv, 128, 0);
242                 if(err==0)
243                 {
244                         host = hst;
245                         serv = srv;
246                 }*/
247
248 } // namespace Net
249 } // namespace Msp