void SystemScheduler::schedule()
{
for(GraphNode &n: nodes)
+ {
n.prerequisites = 0;
+ n.commit_barrier = 0;
+ }
+ // Collect basic prerequisites from system dependencies
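+ // (get_order: negative runs node i first, positive runs node j first)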
for(size_t i=0; i+1<nodes.size(); ++i)
for(size_t j=i+1; j<nodes.size(); ++j)
{
int order = get_order(nodes[i], nodes[j]);
if(order<0)
nodes[j].prerequisites |= 1ULL<<i;
else if(order>0)
nodes[i].prerequisites |= 1ULL<<j;
}
+
+ // Make prerequisites transitive
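+ // (restarting the scan whenever new bits are absorbed, so prerequisites
+ // inherited from earlier nodes are picked up as well)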
+ for(GraphNode &n: nodes)
+ for(size_t i=0; i<nodes.size(); )
+ {
+ if((n.prerequisites&(1ULL<<i)) && (nodes[i].prerequisites&~n.prerequisites))
+ {
+ n.prerequisites |= nodes[i].prerequisites;
+ i = 0;
+ }
+ else
+ ++i;
+ }
+
+ // Check for writes which overlap with reads by other, mutually unordered systems
+ for(size_t i=0; i<nodes.size(); ++i)
+ for(size_t j=i+1; j<nodes.size(); ++j)
+ if(!(nodes[i].prerequisites&(1ULL<<j)) && !(nodes[j].prerequisites&(1ULL<<i)))
+ {
+ unsigned overlap = get_overlapping_writes(nodes[i], nodes[j]);
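+ // 1: node i writes data node j uses, 2: vice versa. The writer must
+ // not commit until the overlapping reader has finished.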
+ if(overlap&1)
+ nodes[i].commit_barrier |= 1ULL<<j;
+ if(overlap&2)
+ nodes[j].commit_barrier |= 1ULL<<i;
+ }
+
+ /* Remove commit barriers which would cause loops, adding a prerequisite in
+ the reverse direction instead so the read cannot happen in the middle of
+ committing the write */
+ for(size_t i=0; i<nodes.size(); ++i)
+ {
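+ // Systems whose commits wait for node i to finish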
+ NodeMask check_bits = 0;
+ for(size_t j=0; j<nodes.size(); ++j)
+ if(nodes[j].commit_barrier&(1ULL<<i))
+ check_bits |= 1ULL<<j;
+
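+ /* j gates i's commit, but j itself only starts after a system whose
+ commit in turn waits on i to finish; break the barrier and require j
+ to start after i has committed instead. */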
+ for(size_t j=0; j<nodes.size(); ++j)
+ if((nodes[i].commit_barrier&(1ULL<<j)) && (nodes[j].prerequisites&check_bits))
+ {
+ nodes[i].commit_barrier &= ~(1ULL<<j);
+ nodes[j].prerequisites |= 1ULL<<i;
+ }
+ }
}
int SystemScheduler::get_order(const GraphNode &node1, const GraphNode &node2)
return data_order;
}
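+// For each dependency shared by both nodes: sets 1 if node1 writes it,
+// otherwise 2 if node2 writes it.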
+unsigned SystemScheduler::get_overlapping_writes(const GraphNode &node1, const GraphNode &node2)
+{
+ unsigned result = 0;
+ for_common_deps(node1, node2, [&result](const Reflection::ClassBase *, int flags1, int flags2){
+ if(flags1&System::WRITE)
+ result |= 1;
+ else if(flags2&System::WRITE)
+ result |= 2;
+ });
+
+ return result;
+}
+
template<typename F>
void SystemScheduler::for_common_deps(const GraphNode &node1, const GraphNode &node2, const F &func)
{
AccessGuard::BlockForScope _block;
#endif
- uint64_t pending = (~0ULL)>>(MAX_SYSTEMS-nodes.size());
- while(pending)
+ NodeMask started = 0;
+ finished = 0;
+ committed = 0;
+ NodeMask all = (~0ULL)>>(MAX_SYSTEMS-nodes.size());
+ while(committed!=all)
{
- for(size_t i=0; i<nodes.size(); ++i)
- if(!(pending&nodes[i].prerequisites))
- {
- run_system(i, dt);
- pending &= ~(1ULL<<i);
- }
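+ /* Prefer starting a runnable system; when nothing can start, commit a
+ finished one. A system may start once all of its prerequisites have
+ committed, and may commit once it and every system in its commit
+ barrier have finished. */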
+ int run_index = -1;
+ int commit_index = -1;
+ for(size_t i=0; (run_index<0 && i<nodes.size()); ++i)
+ {
+ if(!(started&(1ULL<<i)) && (committed&nodes[i].prerequisites)==nodes[i].prerequisites)
+ run_index = i;
+ if((finished&(1ULL<<i)) && !(committed&(1ULL<<i)) && (finished&nodes[i].commit_barrier)==nodes[i].commit_barrier)
+ commit_index = i;
+ }
+
+ if(run_index>=0)
+ {
+ started |= 1ULL<<run_index;
+ run_system(run_index, dt);
+ }
+ else if(commit_index>=0)
+ commit_system(commit_index);
}
}
for(const System::Dependency &d: sys.get_dependencies())
if(Transactor *tract = d.get_transactor())
- {
tract->block(d.get_transact_mode());
+
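+ // Mark as finished, and commit right away if the barrier is already clear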
+ finished |= 1ULL<<index;
+ if((finished&nodes[index].commit_barrier)==nodes[index].commit_barrier)
+ commit_system(index);
+}
+
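+// Commits a system's writes through its transactors and marks it as committed.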
+void SystemScheduler::commit_system(size_t index)
+{
+ System &sys = *nodes[index].system;
+ for(const System::Dependency &d: sys.get_dependencies())
+ if(Transactor *tract = d.get_transactor())
tract->commit(d.get_transact_mode());
- }
+
+ committed |= 1ULL<<index;
}
} // namespace Msp::Game
class MSPGAME_API SystemScheduler
{
public:
- using PrerequisiteMask = std::uint64_t;
- static constexpr unsigned MAX_SYSTEMS = sizeof(PrerequisiteMask)*8;
+ using NodeMask = std::uint64_t;
+ static constexpr unsigned MAX_SYSTEMS = sizeof(NodeMask)*8;
struct GraphNode
{
System *system = nullptr;
Reflection::ClassBase *type = nullptr;
- PrerequisiteMask prerequisites = 0;
+ NodeMask prerequisites = 0;
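+ // Systems which must finish running before this one's writes can be committed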
+ NodeMask commit_barrier = 0;
};
private:
std::vector<GraphNode> nodes;
bool pending_reschedule = false;
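+ // Progress tracking for run(): systems that have finished and committed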
+ NodeMask finished = 0;
+ NodeMask committed = 0;
+
public:
SystemScheduler(Reflection::Reflector &r): reflector(r) { }
static int get_order(const GraphNode &, const GraphNode &);
static int get_explicit_order(const GraphNode &, const GraphNode &);
static int get_data_order(const GraphNode &, const GraphNode &);
+ static unsigned get_overlapping_writes(const GraphNode &, const GraphNode &);
template<typename F>
static void for_common_deps(const GraphNode &, const GraphNode &, const F &);
void run(Time::TimeDelta);
private:
void run_system(std::size_t, Time::TimeDelta);
+ void commit_system(std::size_t);
};
} // namespace Msp::Game
void chained_update();
void parallel_access();
void ambiguous_data_order();
+ void commit_barrier();
};
add(&SchedulerTests::chained_update, "Chained update");
add(&SchedulerTests::parallel_access, "Parallel access");
add(&SchedulerTests::ambiguous_data_order, "Ambiguous data order").expect_throw<Game::scheduling_error>();
+ add(&SchedulerTests::commit_barrier, "Commit barrier");
}
void SchedulerTests::unrelated_components()
scheduler.schedule();
}
+
+void SchedulerTests::commit_barrier()
+{
+ Env env;
+
+ auto &sys1 = env.stage.add_system<Sys<Dep<A, READ_OLD>, Dep<B, UPDATE>>>();
+ auto &sys2 = env.stage.add_system<Sys<Dep<A, READ_OLD>, Dep<B, CHAINED_UPDATE>, Dep<C, UPDATE>>>();
+ auto &sys3 = env.stage.add_system<Sys<Dep<A, UPDATE>, Dep<C, READ_OLD>>>();
+ auto &sys4 = env.stage.add_system<Sys<Dep<A, READ_OLD>, Dep<C, CHAINED_UPDATE>>>();
+
+ Game::SystemScheduler scheduler(env.reflector);
+ scheduler.add_system(sys1);
+ scheduler.add_system(sys2);
+ scheduler.add_system(sys3);
+ scheduler.add_system(sys4);
+
+ scheduler.schedule();
+ const auto &graph = scheduler.get_graph();
+
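+ // One bit per system, in the order they were added: bit 0 = sys1 ... bit 3 = sys4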
+ EXPECT_EQUAL(graph.size(), 4);
+ EXPECT_EQUAL(graph[0].prerequisites, 0);
+ EXPECT_EQUAL(graph[0].commit_barrier, 0);
+ EXPECT_EQUAL(graph[1].prerequisites, 1);   // chains sys1's update of B
+ EXPECT_EQUAL(graph[1].commit_barrier, 4);  // can't commit C until sys3 has read the old value
+ EXPECT_EQUAL(graph[2].prerequisites, 0);
+ EXPECT_EQUAL(graph[2].commit_barrier, 3);  // can't commit A until sys1 and sys2 have finished
+ EXPECT_EQUAL(graph[3].prerequisites, 7);   // chains sys2 (and sys1); sys3's barrier became a prerequisite
+ EXPECT_EQUAL(graph[3].commit_barrier, 4);  // can't commit C until sys3 has finished
+}