文章

一个简便的多态类自动注册设施

实际用例:GitHub

事情是这样的,假设你在用C++做OOP,设计了一个作为接口的纯虚类,然后底下有N个public继承出来的子类,可以用某个工厂函数获取到std::unique_ptr<Base>,然后就可以愉快地调用接口了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Base
{
public:
    virtual void play() = 0;
    virtual ~Base() = default;
};
...
std::unique_ptr<Base> Create();

int main()
{
    auto p = Create();
    p->play();
}

那么在写完每个子类后,我们都得去Create()函数里面添加相对应的代码(假设创建的具体子类由别处的枚举项变量指定):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
enum struct Select
{
    D1, D2, D3
};

std::unique_ptr<Base> Create()
{
    switch (current)
    {
        case Select::D1:
            return new Derived1;
        case Select::D2:
            return new Derived2;
        case Select::D3:
            return new Derived3;
        default:
            return nullptr;
    }
}

这无疑是繁琐的,且容易忘记添加新类代码导致出错,另外还要检查空指针。

理想的简化流程是,每当写完一个新子类后,我们就不需要做其他额外的工作了,子类会自动将自身注册到Create()函数的逻辑里面。

那么我们马上就会想到的一个思路就是,在写子类定义的时候做一点手脚,在里面注册自身,然后将共同的代码提取成宏,就可以用一句宏解决问题了。

譬如说有个子类长这样:

1
2
3
4
5
6
7
class Derived1 : public Base
{
public:
    void play() override;
private:
    // data members
};

很明显,一个类只需要注册一次,且必须在其任一实例产生前注册,我们可以利用类的静态变量来做小动作:

1
2
3
4
5
6
7
8
class Derived1 : public Base
{
    inline static bool _is_registered = [] {
        // registering...
        return true;
    }();
    ...
};

我们用一个返回true的lambda来注册,程序启动时就会执行。这里用到的内联静态变量特性需要C++17,没有的话就得写在类外面。

那么我们要用什么来保存注册信息呢,索引项是枚举,得到一个类实例,这里可以将产生不同子类实例的代码提取为不同的函数,用枚举索引函数指针,就能免去重复代码了。

我们用std::unordered_map<Select, std::unique_ptr<Base>(*)()>作为容器,也将其作为静态变量放在最顶层的接口类里:

1
2
3
4
5
6
7
8
9
10
class Base
{
protected:
    static auto& _getMap()
    {
        static std::unordered_map<Select, std::unique_ptr<Base>(*)()> map;
        return map;
    }
    ...
};

依最小作用域原则,这里把map放在函数里,注意函数为protected,这样内部的所有子类都能访问到。

现在我们就能进行注册了:

1
2
3
4
5
inline static bool _is_registered = [] {
    auto& map = _getMap();
    map[Select::Derived1] = [] { return new Derived1; };
    return true;
}();

但是到这里就会悲催地发现编译无法通过,原因是我们在Derived1还没定义完的时候就写下了new Derived1,使用了一个不完整类型。

解决的方法就是利用函数模板的延后实例化机制来解决,我们只需要把new语句放到函数模板里面去就行了,使用时传入要new的类名。我们把这个函数模板也放到接口类里面去:

1
2
3
4
5
6
7
8
9
10
11
class Base
{
    ...
protected:
    template<typename T>
    static std::unique_ptr<Base> _FactoryNew()
    {
        return std::make_unique<T>();
    }
    ...
};

这样在子类里传_FactoryNew<Derived1>作为函数指针就解决了。注意使用此设施后枚举项的名字要改成和子类名一样。

一切完成后,Create函数就一劳永逸地改成下面这个样子:

1
2
3
4
5
6
7
std::unique_ptr<Base> Base::Create()
{
    auto& map = Base::_getMap();
    auto& key = current;
    assert(map.find(key) != map.end());
    return map[key]();
}

Create()放到Base里作为static函数有两个好处,一是可访问到_getMap而不用将其改成public暴露出来,二是让接口使用更清晰明了。

assert的作用是保证使用前子类已注册过,防止忘记注册。

最后我们把这些小动作提取成宏,并加入错误检测:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#define FACTORY_MAP_DEFINE_(class_name, key_scope) \
protected: \
    static auto& _getMap() \
    { \
        static std::unordered_map<key_scope, std::unique_ptr<class_name>(*)()> map; \
        return map; \
    } \
    template<typename T> \
    static std::unique_ptr<class_name> _FactoryNew() \
    { \
        return std::make_unique<T>(); \
    }

#define FACTORY_MAP_REGISTER_(class_name, key_scope) \
private: \
    inline static bool _is_registered = [] \
        { \
            auto& map = class_name::_getMap(); \
            /* avoid register more than once */ \
            assert(map.find(key_scope::class_name) == map.end()); \
            map[key_scope::class_name] = class_name::_FactoryNew<class_name>; \
            return true; \
        }()

#define FACTORY_MAP_DEFINE(class_name) FACTORY_MAP_DEFINE_(class_name, KEY_SCOPE)
#define FACTORY_MAP_REGISTER(class_name) FACTORY_MAP_REGISTER_(class_name, KEY_SCOPE)

#define KEY_SCOPE Select

这里assert的作用是保证注册前必须没有注册过,避免注册多次,与第一个assert配合保证子类只注册一次。

最终我们就可以这样写,用一行解决问题:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Base
{
    FACTORY_MAP_DEFINE(Base);
public:
    static std::unique_ptr<Base> Create();
public:
    virtual void play() = 0;
    virtual ~Base() = default;
};

class Derived1 :public Base
{
    FACTORY_MAP_REGISTER(Derived1);
public:
    void play() override;
};

class Derived2 :public Base
{
    FACTORY_MAP_REGISTER(Derived2);
public:
    void play() override;
};
...

本文由作者按照 CC BY-NC-SA 4.0 进行授权