]> git.tdb.fi Git - r2c2.git/blobdiff - source/libr2c2/trainroutemetric.cpp
Add a distance metric to turn the routing into an A* search
[r2c2.git] / source / libr2c2 / trainroutemetric.cpp
diff --git a/source/libr2c2/trainroutemetric.cpp b/source/libr2c2/trainroutemetric.cpp
new file mode 100644 (file)
index 0000000..a2cffbc
--- /dev/null
@@ -0,0 +1,90 @@
+#include <list>
+#include "track.h"
+#include "trackchain.h"
+#include "trainroutemetric.h"
+
+using namespace std;
+
+namespace R2C2 {
+
+TrainRouteMetric::TrainRouteMetric(const TrackChain &tc)
+{
+       const TrackChain::TrackSet &ctracks = tc.get_tracks();
+       for(TrackChain::TrackSet::const_iterator i=ctracks.begin(); i!=ctracks.end(); ++i)
+       {
+               unsigned nls = (*i)->get_n_link_slots();
+               for(unsigned j=0; j<nls; ++j)
+                       if(Track *link = (*i)->get_link(j))
+                               if(!ctracks.count(link))
+                                       goals.push_back(TrackIter(*i, j));
+       }
+
+       list<TrackIter> queue;
+       for(vector<Goal>::iterator i=goals.begin(); i!=goals.end(); ++i)
+       {
+               tracks[Key(i->track.track(), i->track.entry())] = Data(0, &*i);
+               queue.push_back(i->track);
+       }
+
+       while(!queue.empty())
+       {
+               TrackIter track = queue.front();
+               queue.pop_front();
+               const Data &data = tracks[Key(track.track(), track.entry())];
+
+               const TrackType::Endpoint &ep = track.endpoint();
+               for(unsigned i=0; ep.paths>>i; ++i)
+                       if(ep.has_path(i))
+                       {
+                               TrackIter next = track.next(i);
+                               if(!next)
+                                       continue;
+
+                               Data &target = tracks[Key(next.track(), next.entry())];
+                               float dist = data.distance+track->get_type().get_path_length(i);
+                               if(target.distance<0 || target.distance>dist)
+                               {
+                                       target = Data(dist, data.goal);
+                                       queue.push_back(next);
+                               }
+                       }
+       }
+}
+
+void TrainRouteMetric::chain_to(const TrainRouteMetric &metric)
+{
+       for(vector<Goal>::iterator i=goals.begin(); i!=goals.end(); ++i)
+               i->base_distance = metric.get_distance_from(*i->track.track(), i->track.entry());
+}
+
+float TrainRouteMetric::get_distance_from(const Track &track, unsigned exit) const
+{
+       map<Key, Data>::const_iterator i = tracks.find(Key(&track, exit));
+       if(i==tracks.end())
+               return -1;
+
+       return i->second.distance+i->second.goal->base_distance;
+}
+
+
+TrainRouteMetric::Goal::Goal():
+       base_distance(0)
+{ }
+
+TrainRouteMetric::Goal::Goal(const TrackIter &t):
+       track(t),
+       base_distance(0)
+{ }
+
+
+TrainRouteMetric::Data::Data():
+       distance(-1),
+       goal(0)
+{ }
+
+TrainRouteMetric::Data::Data(float d, const Goal *g):
+       distance(d),
+       goal(g)
+{ }
+
+} // namespace R2C2