Tuesday, May 29, 2012

Get All Derived Types of a Class

A common requirement in modular .NET applications is the ability to get all types that derive from a given base class. The task sounds easy at first, since the .NET Framework provides the convenient Type.IsSubclassOf method, but it is not! The problem arises when generics are involved, for example when you want to find all types in a specific assembly that inherit from a given generic class. If you pass the open generic type definition (such as typeof(BaseClass<>)) as the base type, Type.IsSubclassOf returns false, because the derived classes inherit from closed constructed types such as BaseClass<int>, not from the generic type definition itself.
 
For example, if we have the following C# classes:
public class BaseClass<T>
{
}

public class ChildClass1 : BaseClass<int>
{
}

public class ChildClass2 : BaseClass<string>
{
}

public class GrandChildClass1 : ChildClass1
{
}
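The problem described above is easy to verify. The following self-contained console sketch (the Program class and its Main method are only scaffolding for illustration) shows that Type.IsSubclassOf recognizes the closed constructed base type BaseClass<int> but not the open generic definition BaseClass<>. Comparing the base type's generic type definition is what reveals the relationship, and that is the idea the helper class at the end of this article builds on.

```csharp
using System;

public class BaseClass<T> { }
public class ChildClass1 : BaseClass<int> { }
public class ChildClass2 : BaseClass<string> { }
public class GrandChildClass1 : ChildClass1 { }

public static class Program
{
    public static void Main()
    {
        // ChildClass1 derives from the closed constructed type BaseClass<int>...
        Console.WriteLine(typeof(ChildClass1).IsSubclassOf(typeof(BaseClass<int>))); // True

        // ...but not from the open generic type definition BaseClass<>
        Console.WriteLine(typeof(ChildClass1).IsSubclassOf(typeof(BaseClass<>)));    // False

        // The generic type definition of the direct base type is BaseClass<>
        Type definition = typeof(ChildClass1).BaseType.GetGenericTypeDefinition();
        Console.WriteLine(definition == typeof(BaseClass<>));                        // True
    }
}
```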

Using the VType helper class presented at the end of this article, we can get all types that derive from the generic base class like this:
List<Type> derivedTypes = VType.GetDerivedTypes(typeof(BaseClass<>),
    Assembly.GetExecutingAssembly());
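For the example classes, this call should return ChildClass1, ChildClass2 and GrandChildClass1 (GrandChildClass1 inherits from BaseClass<int> indirectly, through ChildClass1). The self-contained sketch below reproduces that result. It uses a hypothetical, condensed LINQ-based helper named DerivedTypeFinder, which is not the VType class from this article and only handles open generic type definitions. It sorts the result by name because Assembly.GetTypes does not return types in any particular order.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

public class BaseClass<T> { }
public class ChildClass1 : BaseClass<int> { }
public class ChildClass2 : BaseClass<string> { }
public class GrandChildClass1 : ChildClass1 { }

// Hypothetical helper for illustration only (not the VType class)
public static class DerivedTypeFinder
{
    // Yields the strict ancestors of a type, starting at its direct base type
    private static IEnumerable<Type> Ancestors(Type type)
    {
        for (Type t = type.BaseType; t != null; t = t.BaseType)
            yield return t;
    }

    // True if any ancestor is a construction of the open generic definition
    public static bool DerivesFromGeneric(Type type, Type genericDefinition)
    {
        return Ancestors(type).Any(t =>
            t.IsGenericType && t.GetGenericTypeDefinition() == genericDefinition);
    }

    public static List<Type> Find(Type genericDefinition, Assembly assembly)
    {
        return assembly.GetTypes()
            .Where(t => DerivesFromGeneric(t, genericDefinition))
            .OrderBy(t => t.Name) // GetTypes() guarantees no particular order
            .ToList();
    }
}

public static class Program
{
    public static void Main()
    {
        List<Type> derivedTypes = DerivedTypeFinder.Find(typeof(BaseClass<>),
            Assembly.GetExecutingAssembly());

        foreach (Type type in derivedTypes)
            Console.WriteLine(type.Name);
        // Prints:
        // ChildClass1
        // ChildClass2
        // GrandChildClass1
    }
}
```

Moving the ancestor walk into an iterator keeps the matching condition in a single Any call; the VType class below follows the same idea with an explicit loop and also accepts non-generic base types.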

Here's the C# source code of the helper class that we can use to get the derived types of any class (including generic ones):
using System;
using System.Collections.Generic;
using System.Reflection;

public static class VType
{
    public static List<Type> GetDerivedTypes(Type baseType, Assembly assembly)
    {
        // Get all types from the given assembly
        // (throws ReflectionTypeLoadException if some types cannot be loaded)
        Type[] types = assembly.GetTypes();
        List<Type> derivedTypes = new List<Type>();

        foreach (Type type in types)
        {
            if (VType.IsSubclassOf(type, baseType))
            {
                // The current type is derived from the base type,
                // so add it to the list
                derivedTypes.Add(type);
            }
        }

        return derivedTypes;
    }

    public static bool IsSubclassOf(Type type, Type baseType)
    {
        if (type == null || baseType == null || type == baseType)
            return false;

        // Non-generic base types and closed constructed generic types
        // (e.g. BaseClass<int>) are handled correctly by the built-in method,
        // so BaseClass<int> does not match ChildClass2 : BaseClass<string>
        if (!baseType.IsGenericTypeDefinition)
            return type.IsSubclassOf(baseType);

        // For an open generic type definition (e.g. BaseClass<>), walk up the
        // inheritance chain and compare the generic type definition of each
        // ancestor with the base type
        type = type.BaseType;

        while (type != null)
        {
            Type currentType = type.IsGenericType ?
                type.GetGenericTypeDefinition() : type;
            if (currentType == baseType)
                return true;

            type = type.BaseType;
        }

        return false;
    }
}
