--- /dev/null
+#ifndef MSP_GEOMETRY_EXTRUDEDSHAPE_H_
+#define MSP_GEOMETRY_EXTRUDEDSHAPE_H_
+
+#include <algorithm>
+#include <cmath>
+#include "shape.h"
+
+namespace Msp {
+namespace Geometry {
+
+/**
+A shape embedded in space of dimension higher by one and extruded towards the
+highest dimension. As an example, extruding a circle creates a cylinder. The
+base shape's orientation is not changed.
+*/
+template<typename T, unsigned D>
+class ExtrudedShape: public Shape<T, D>
+{
+private:
+ Shape<T, D-1> *base;
+ T length;
+
+public:
+ ExtrudedShape(const Shape<T, D-1> &, T);
+ ExtrudedShape(const ExtrudedShape &);
+ ExtrudedShape &operator=(const ExtrudedShape &);
+ virtual ~ExtrudedShape();
+
+ virtual ExtrudedShape *clone() const;
+
+ const Shape<T, D-1> &get_base() const { return *base; }
+ T get_length() const { return length; }
+
+ virtual HyperBox<T, D> get_axis_aligned_bounding_box() const;
+ virtual bool contains(const LinAl::Vector<T, D> &) const;
+ virtual bool check_intersection(const Ray<T, D> &) const;
+ virtual unsigned get_max_ray_intersections() const;
+ virtual unsigned get_intersections(const Ray<T, D> &, SurfacePoint<T, D> *, unsigned) const;
+};
+
+template<typename T, unsigned D>
+inline ExtrudedShape<T, D>::ExtrudedShape(const Shape<T, D-1> &b, T l):
+ length(l)
+{
+ if(l<=0)
+ throw std::invalid_argument("ExtrudedShape::ExtrudedShape");
+
+ base = b.clone();
+}
+
+template<typename T, unsigned D>
+inline ExtrudedShape<T, D>::ExtrudedShape(const ExtrudedShape<T, D> &other):
+ base(other.base.clone()),
+ length(other.length)
+{ }
+
+template<typename T, unsigned D>
+inline ExtrudedShape<T, D> &ExtrudedShape<T, D>::operator=(const ExtrudedShape<T, D> &other)
+{
+ delete base;
+ base = other.base.clone();
+ length = other.length;
+}
+
+template<typename T, unsigned D>
+inline ExtrudedShape<T, D>::~ExtrudedShape()
+{
+ delete base;
+}
+
+template<typename T, unsigned D>
+inline ExtrudedShape<T, D> *ExtrudedShape<T, D>::clone() const
+{
+ return new ExtrudedShape<T, D>(*base, length);
+}
+
+template<typename T, unsigned D>
+inline HyperBox<T, D> ExtrudedShape<T, D>::get_axis_aligned_bounding_box() const
+{
+ HyperBox<T, D-1> base_bbox = base->get_axis_aligned_bounding_box();
+ return HyperBox<T, D>(LinAl::Vector<T, D>(base_bbox.get_dimensions(), length));
+}
+
+template<typename T, unsigned D>
+inline bool ExtrudedShape<T, D>::contains(const LinAl::Vector<T, D> &point) const
+{
+ using std::abs;
+
+ if(abs(point[D-1])>length/T(2))
+ return false;
+
+ return base->contains(LinAl::Vector<T, D-1>(point));
+}
+
+template<typename T, unsigned D>
+inline bool ExtrudedShape<T, D>::check_intersection(const Ray<T, D> &ray) const
+{
+ return get_intersections(ray, 0, 1);
+}
+
+template<typename T, unsigned D>
+inline unsigned ExtrudedShape<T, D>::get_max_ray_intersections() const
+{
+ return std::max(base->get_max_ray_intersections(), 2U);
+}
+
+template<typename T, unsigned D>
+inline unsigned ExtrudedShape<T, D>::get_intersections(const Ray<T, D> &ray, SurfacePoint<T, D> *points, unsigned size) const
+{
+ using std::abs;
+ using std::sqrt;
+ using std::swap;
+
+ unsigned n = 0;
+ T half_length = length/T(2);
+ const LinAl::Vector<T, D> &ray_start = ray.get_start();
+ const LinAl::Vector<T, D> &ray_direction = ray.get_direction();
+ LinAl::Vector<T, D-1> base_dir(ray_direction);
+
+ /* If the ray does not degenerate to a point in the base space, it could
+ intersect the base shape. */
+ if(inner_product(base_dir, base_dir)!=T(0))
+ {
+ T offset = T();
+ T limit = T();
+ if(ray.get_direction()[D-1]!=T(0))
+ {
+ offset = (half_length-ray_start[D-1])/ray_direction[D-1];
+ limit = (-half_length-ray_start[D-1])/ray_direction[D-1];
+ if(offset>limit)
+ swap(offset, limit);
+ if(offset<T(0))
+ offset = T(0);
+ }
+ T distortion = base_dir.norm();
+ Ray<T, D-1> base_ray(LinAl::Vector<T, D-1>(ray_start+ray_direction*offset),
+ base_dir, (limit-offset)*distortion);
+
+ SurfacePoint<T, D-1> *base_points = 0;
+ if(points)
+ /* Shamelessly reuse the provided storage. Align to the end of the array
+ so processing can start from the first (nearest) point. */
+ base_points = reinterpret_cast<SurfacePoint<T, D-1> *>(points+size)-size;
+
+ unsigned count = base->get_intersections(base_ray, base_points, size);
+ for(unsigned i=0; i<count; ++i)
+ {
+ if(points)
+ {
+ T x = offset+base_points[i].distance/distortion;
+ points[n].position = ray_start+ray_direction*x;
+ points[n].normal = LinAl::Vector<T, D>(base_points[i].normal, T(0));
+ points[n].distance = x;
+ }
+
+ ++n;
+ if(n==size)
+ return n;
+ }
+ }
+
+ /* If the ray is not parallel to the base space, it may pass through the
+ caps. */
+ if(ray_direction[D-1])
+ {
+ for(int i=-1; i<=1; i+=2)
+ {
+ T x = (half_length*i-ray_start[D-1])/ray_direction[D-1];
+ if(!ray.check_limits(x))
+ continue;
+
+ LinAl::Vector<T, D> p = ray_start+ray_direction*x;
+ if(base->contains(LinAl::Vector<T, D-1>(p)) && n<size)
+ {
+ if(points)
+ {
+ points[n].position = p;
+ points[n].normal = LinAl::Vector<T, D>();
+ points[n].normal[D-1] = i;
+ points[n].distance = x;
+
+ if(n==1 && x<points[0].distance)
+ swap(points[0], points[1]);
+ }
+
+ ++n;
+ if(n==size)
+ return n;
+ }
+ }
+ }
+
+ return n;
+}
+
+} // namespace Geometry
+} // namespace Msp
+
+#endif