Tutorial 23: Frustum Culling

The three-dimensional region of the scene that is visible on the screen is called the viewing frustum. Everything inside the frustum is rendered to the screen by the video card. Everything outside the frustum is examined by the video card and then discarded during the rendering process.

However, relying on the video card to cull for us can be expensive in large scenes. For example, say we have a scene with 2,000 models of 5,000 polygons each, but only 10 to 20 of them are viewable at any given time. The video card has to examine every triangle in all 2,000 models, and discard roughly 1,990 of those models, just so we can draw 10. This is very inefficient.

Frustum culling solves this problem by determining, before rendering, whether each model is inside the frustum. This saves us from sending all the triangles to the video card and lets us send only the triangles that need to be drawn. To do this we put a simple bounding volume (a cube, a rectangle, or a sphere) around each model and calculate whether that volume is viewable. The math for that is usually only a couple of lines of code, which removes the need to test possibly several thousand triangles.
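
As a rough illustration of how a bounding sphere could be fitted around a model, here is a minimal sketch. The Vec3 and BoundingSphere types and the ComputeBoundingSphere function are hypothetical names made up for this example; they are not part of the tutorial framework, and this tutorial itself simply uses a fixed radius for its spheres.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical types for illustration; ModelClass stores its vertices differently.
struct Vec3 { float x, y, z; };
struct BoundingSphere { Vec3 center; float radius; };

// Center the sphere on the midpoint of the axis-aligned extents, then grow the
// radius until it reaches the farthest vertex.
BoundingSphere ComputeBoundingSphere(const std::vector<Vec3>& vertices)
{
    BoundingSphere sphere = { { 0.0f, 0.0f, 0.0f }, 0.0f };
    if(vertices.empty())
    {
        return sphere;
    }

    // Find the minimum and maximum extent on each axis.
    Vec3 lo = vertices[0];
    Vec3 hi = vertices[0];
    for(const Vec3& v : vertices)
    {
        lo.x = std::min(lo.x, v.x);  hi.x = std::max(hi.x, v.x);
        lo.y = std::min(lo.y, v.y);  hi.y = std::max(hi.y, v.y);
        lo.z = std::min(lo.z, v.z);  hi.z = std::max(hi.z, v.z);
    }

    sphere.center.x = (lo.x + hi.x) * 0.5f;
    sphere.center.y = (lo.y + hi.y) * 0.5f;
    sphere.center.z = (lo.z + hi.z) * 0.5f;

    // The radius is the distance from that center to the farthest vertex.
    float maxDistanceSquared = 0.0f;
    for(const Vec3& v : vertices)
    {
        float dx = v.x - sphere.center.x;
        float dy = v.y - sphere.center.y;
        float dz = v.z - sphere.center.z;
        maxDistanceSquared = std::max(maxDistanceSquared, (dx * dx) + (dy * dy) + (dz * dz));
    }

    sphere.radius = std::sqrt(maxDistanceSquared);
    return sphere;
}
```

The sphere produced this way is not the tightest possible one, but it is guaranteed to contain every vertex, which is all that a culling test needs.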

To demonstrate how this works we will first create a scene with 25 randomly placed spheres. We will then rotate the camera manually with the left and right arrow keys to test culling of the spheres that are out of view. We will also display a count of the spheres that are being drawn (not culled) for confirmation. We will use code from several of the previous tutorials to create the scene.


Framework

The framework consists mostly of classes from several of the previous tutorials. There are three new classes: FrustumClass, PositionClass, and ModelListClass. FrustumClass will encapsulate the frustum culling ability this tutorial is focused on. ModelListClass will contain a list of the position information of the 25 spheres that will be randomly generated each time we run the program. PositionClass will handle the viewing rotation of the camera based on whether the user is pressing the left or right arrow key.


Frustumclass.h

The header file for the FrustumClass is fairly simple. The class doesn't require any initialization or shutdown. Each frame, the ConstructFrustum function is called after the camera has been rendered. ConstructFrustum calculates the six planes of the view frustum from the updated viewing location and stores them in the private m_planes array. From there we can call any of the four check functions to see whether a point, cube, sphere, or rectangle is inside the viewing frustum.

////////////////////////////////////////////////////////////////////////////////
// Filename: frustumclass.h
////////////////////////////////////////////////////////////////////////////////
#ifndef _FRUSTUMCLASS_H_
#define _FRUSTUMCLASS_H_


//////////////
// INCLUDES //
//////////////
#include <directxmath.h>
using namespace DirectX;


////////////////////////////////////////////////////////////////////////////////
// Class name: FrustumClass
////////////////////////////////////////////////////////////////////////////////
class FrustumClass
{
public:
    FrustumClass();
    FrustumClass(const FrustumClass&);
    ~FrustumClass();

    void ConstructFrustum(XMMATRIX, XMMATRIX, float);

    bool CheckPoint(float, float, float);
    bool CheckCube(float, float, float, float);
    bool CheckSphere(float, float, float, float);
    bool CheckRectangle(float, float, float, float, float, float);

private:
    XMFLOAT4 m_planes[6];
};

#endif

Frustumclass.cpp

////////////////////////////////////////////////////////////////////////////////
// Filename: frustumclass.cpp
////////////////////////////////////////////////////////////////////////////////
#include "frustumclass.h"


FrustumClass::FrustumClass()
{
}


FrustumClass::FrustumClass(const FrustumClass& other)
{
}


FrustumClass::~FrustumClass()
{
}

ConstructFrustum is called every frame by the ApplicationClass. It is passed the view matrix, the projection matrix, and the depth of the screen. We use these inputs to calculate the view frustum matrix for that frame, and from that matrix we calculate the six planes that form the view frustum.

void FrustumClass::ConstructFrustum(XMMATRIX viewMatrix, XMMATRIX projectionMatrix, float screenDepth)
{
    XMMATRIX finalMatrix;
    XMFLOAT4X4 projMatrix, matrix;
    float zMinimum, r, t;


    // Load the projection matrix into a XMFLOAT4X4 structure.
    XMStoreFloat4x4(&projMatrix, projectionMatrix);

    // Calculate the minimum Z distance in the frustum.
    zMinimum = -projMatrix._43 / projMatrix._33;
    r = screenDepth / (screenDepth - zMinimum);
    projMatrix._33 = r;
    projMatrix._43 = -r * zMinimum;

    // Load the updated XMFLOAT4X4 back into the original projection matrix.
    projectionMatrix = XMLoadFloat4x4(&projMatrix);

    // Create the frustum matrix from the view matrix and updated projection matrix.
    finalMatrix = XMMatrixMultiply(viewMatrix, projectionMatrix);

    // Load the final matrix into a XMFLOAT4X4 structure.
    XMStoreFloat4x4(&matrix, finalMatrix);

    // Get the near plane of the frustum.
    m_planes[0].x = matrix._13;
    m_planes[0].y = matrix._23;
    m_planes[0].z = matrix._33;
    m_planes[0].w = matrix._43;

    // Normalize it.
    t = (float)sqrt((m_planes[0].x * m_planes[0].x) + (m_planes[0].y * m_planes[0].y) + (m_planes[0].z * m_planes[0].z));
    m_planes[0].x /= t;
    m_planes[0].y /= t;
    m_planes[0].z /= t;
    m_planes[0].w /= t;

    // Calculate the far plane of frustum.
    m_planes[1].x = matrix._14 - matrix._13; 
    m_planes[1].y = matrix._24 - matrix._23;
    m_planes[1].z = matrix._34 - matrix._33;
    m_planes[1].w = matrix._44 - matrix._43;

    // Normalize it.
    t = (float)sqrt((m_planes[1].x * m_planes[1].x) + (m_planes[1].y * m_planes[1].y) + (m_planes[1].z * m_planes[1].z));
    m_planes[1].x /= t;
    m_planes[1].y /= t;
    m_planes[1].z /= t;
    m_planes[1].w /= t;

    // Calculate the left plane of frustum.
    m_planes[2].x = matrix._14 + matrix._11; 
    m_planes[2].y = matrix._24 + matrix._21;
    m_planes[2].z = matrix._34 + matrix._31;
    m_planes[2].w = matrix._44 + matrix._41;

    // Normalize it.
    t = (float)sqrt((m_planes[2].x * m_planes[2].x) + (m_planes[2].y * m_planes[2].y) + (m_planes[2].z * m_planes[2].z));
    m_planes[2].x /= t;
    m_planes[2].y /= t;
    m_planes[2].z /= t;
    m_planes[2].w /= t;

    // Calculate the right plane of frustum.
    m_planes[3].x = matrix._14 - matrix._11; 
    m_planes[3].y = matrix._24 - matrix._21;
    m_planes[3].z = matrix._34 - matrix._31;
    m_planes[3].w = matrix._44 - matrix._41;

    // Normalize it.
    t = (float)sqrt((m_planes[3].x * m_planes[3].x) + (m_planes[3].y * m_planes[3].y) + (m_planes[3].z * m_planes[3].z));
    m_planes[3].x /= t;
    m_planes[3].y /= t;
    m_planes[3].z /= t;
    m_planes[3].w /= t;

    // Calculate the top plane of frustum.
    m_planes[4].x = matrix._14 - matrix._12; 
    m_planes[4].y = matrix._24 - matrix._22;
    m_planes[4].z = matrix._34 - matrix._32;
    m_planes[4].w = matrix._44 - matrix._42;

    // Normalize it.
    t = (float)sqrt((m_planes[4].x * m_planes[4].x) + (m_planes[4].y * m_planes[4].y) + (m_planes[4].z * m_planes[4].z));
    m_planes[4].x /= t;
    m_planes[4].y /= t;
    m_planes[4].z /= t;
    m_planes[4].w /= t;

    // Calculate the bottom plane of frustum.
    m_planes[5].x = matrix._14 + matrix._12;
    m_planes[5].y = matrix._24 + matrix._22;
    m_planes[5].z = matrix._34 + matrix._32;
    m_planes[5].w = matrix._44 + matrix._42;

    // Normalize it.
    t = (float)sqrt((m_planes[5].x * m_planes[5].x) + (m_planes[5].y * m_planes[5].y) + (m_planes[5].z * m_planes[5].z));
    m_planes[5].x /= t;
    m_planes[5].y /= t;
    m_planes[5].z /= t;
    m_planes[5].w /= t;

    return;
}
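
The six blocks above all follow the same pattern: combine the fourth column of the matrix with one of the first three columns, then normalize. A minimal sketch of that pattern as a loop-friendly helper is shown below. The Plane and Matrix4 structs and the NormalizePlane, MakePlane, and ExtractFrustumPlanes functions are hypothetical stand-ins so that the sketch builds without DirectXMath; they are not part of the tutorial code.

```cpp
#include <cmath>

// Hypothetical stand-ins for XMFLOAT4 and XMFLOAT4X4. m[row][col] is laid out
// so that m[0][2] corresponds to matrix._13.
struct Plane { float x, y, z, w; };
struct Matrix4 { float m[4][4]; };

// Divide all four components by the length of the normal (x, y, z).
static Plane NormalizePlane(Plane p)
{
    float length = std::sqrt((p.x * p.x) + (p.y * p.y) + (p.z * p.z));
    if(length > 0.0f)
    {
        p.x /= length;
        p.y /= length;
        p.z /= length;
        p.w /= length;
    }
    return p;
}

// Build one plane as (wScale * column 4) + (sign * column col), then normalize it.
static Plane MakePlane(const Matrix4& mat, int col, float sign, float wScale)
{
    float c[4];
    for(int row = 0; row < 4; row++)
    {
        c[row] = (wScale * mat.m[row][3]) + (sign * mat.m[row][col]);
    }

    Plane p = { c[0], c[1], c[2], c[3] };
    return NormalizePlane(p);
}

// Same plane order as m_planes: near, far, left, right, top, bottom.
void ExtractFrustumPlanes(const Matrix4& viewProj, Plane planes[6])
{
    planes[0] = MakePlane(viewProj, 2,  1.0f, 0.0f);  // near:   column 3
    planes[1] = MakePlane(viewProj, 2, -1.0f, 1.0f);  // far:    column 4 - column 3
    planes[2] = MakePlane(viewProj, 0,  1.0f, 1.0f);  // left:   column 4 + column 1
    planes[3] = MakePlane(viewProj, 0, -1.0f, 1.0f);  // right:  column 4 - column 1
    planes[4] = MakePlane(viewProj, 1, -1.0f, 1.0f);  // top:    column 4 - column 2
    planes[5] = MakePlane(viewProj, 1,  1.0f, 1.0f);  // bottom: column 4 + column 2
}
```

A quick sanity check for a helper like this is to pass in an identity matrix: the six planes should then bound the box where x and y run from -1 to 1 and z runs from 0 to 1.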

CheckPoint checks if a single point is inside the viewing frustum. This is the most basic of the four checking algorithms, but it can be very efficient compared to the other checking methods if used in the right situation. It evaluates the plane equation for the point against each of the six planes. If the result is negative for any plane, the point is behind that plane and the function returns false. If the point is on the inner side of all six planes, it returns true.

bool FrustumClass::CheckPoint(float x, float y, float z)
{
    int i;


    // Check if the point is inside all six planes of the view frustum.
    for(i=0; i<6; i++) 
    {
        if(((m_planes[i].x * x) + (m_planes[i].y * y) + (m_planes[i].z * z) + m_planes[i].w) < 0.0f)
        {
            return false;
        }
    }

    return true;
}
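
The expression inside the loop is the signed distance from the point to the plane, which is worth seeing in isolation. The sketch below pulls it out into a helper; the Plane struct and the SignedDistance and PointInsidePlanes functions are hypothetical names used only for this example.

```cpp
// Hypothetical stand-in for XMFLOAT4: (x, y, z) is the plane normal, w the offset.
struct Plane { float x, y, z, w; };

// Signed distance from a point to a normalized plane. Positive values lie on
// the side the normal points toward, which for the frustum planes is the inside.
float SignedDistance(const Plane& plane, float x, float y, float z)
{
    return (plane.x * x) + (plane.y * y) + (plane.z * z) + plane.w;
}

// Same logic as CheckPoint, written against the helper above.
bool PointInsidePlanes(const Plane* planes, int planeCount, float x, float y, float z)
{
    for(int i = 0; i < planeCount; i++)
    {
        if(SignedDistance(planes[i], x, y, z) < 0.0f)
        {
            return false;
        }
    }

    return true;
}
```

For example, with the plane (1, 0, 0, 1) the origin has a signed distance of 1 and is in front of the plane, while the point (-3, 0, 0) has a signed distance of -2 and is behind it.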

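CheckCube checks whether a cube is at least partly inside the viewing frustum. It only requires as input the center point of the cube and the radius (half the edge length); it uses those to calculate the eight corner points of the cube. For each of the six planes it looks for at least one corner on the inner side of that plane. If all eight corners are behind any single plane, the cube cannot be seen and the function returns false; otherwise it returns true. Note that this test is conservative: a cube lying just outside an edge or corner of the frustum can pass all six plane checks without actually being visible, which costs an unnecessary draw but never culls something that should be shown.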

bool FrustumClass::CheckCube(float xCenter, float yCenter, float zCenter, float radius)
{
    int i;


    // For each plane, look for at least one corner of the cube on the inner side. If none is found the cube is culled.
    for(i=0; i<6; i++)
    {
        if(m_planes[i].x * (xCenter - radius) +
           m_planes[i].y * (yCenter - radius) +
           m_planes[i].z * (zCenter - radius) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter + radius) +
           m_planes[i].y * (yCenter - radius) +
           m_planes[i].z * (zCenter - radius) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter - radius) +
           m_planes[i].y * (yCenter + radius) +
           m_planes[i].z * (zCenter - radius) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter + radius) +
           m_planes[i].y * (yCenter + radius) +
           m_planes[i].z * (zCenter - radius) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter - radius) +
           m_planes[i].y * (yCenter - radius) +
           m_planes[i].z * (zCenter + radius) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter + radius) +
           m_planes[i].y * (yCenter - radius) +
           m_planes[i].z * (zCenter + radius) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter - radius) +
           m_planes[i].y * (yCenter + radius) +
           m_planes[i].z * (zCenter + radius) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter + radius) +
           m_planes[i].y * (yCenter + radius) +
           m_planes[i].z * (zCenter + radius) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        return false;
    }

    return true;
}
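
The eight corner tests above can also be generated in a loop, which makes the per-plane logic easier to see. The following is a minimal sketch intended to give the same result as CheckCube; the Plane struct and the CubeTouchesFrustum function are hypothetical names for this example only.

```cpp
// Hypothetical stand-in for XMFLOAT4: (x, y, z) is the plane normal, w the offset.
struct Plane { float x, y, z, w; };

// Returns false only when all eight corners lie behind a single plane.
bool CubeTouchesFrustum(const Plane planes[6], float xCenter, float yCenter, float zCenter, float radius)
{
    for(int i = 0; i < 6; i++)
    {
        bool anyInFront = false;

        for(int corner = 0; corner < 8 && !anyInFront; corner++)
        {
            // Bits 0, 1, and 2 of the corner index choose -radius or +radius on x, y, and z.
            float x = xCenter + ((corner & 1) ? radius : -radius);
            float y = yCenter + ((corner & 2) ? radius : -radius);
            float z = zCenter + ((corner & 4) ? radius : -radius);

            if((planes[i].x * x) + (planes[i].y * y) + (planes[i].z * z) + planes[i].w >= 0.0f)
            {
                anyInFront = true;
            }
        }

        // Every corner is behind this plane, so the cube cannot be visible.
        if(!anyInFront)
        {
            return false;
        }
    }

    return true;
}
```

The inner loop stops at the first corner found in front of the plane, which mirrors the continue statements in the unrolled version.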

CheckSphere checks whether a sphere, given by its center point and radius, is at least partly inside all six planes of the viewing frustum. For each plane it computes the signed distance from the center to the plane. If that distance is less than the negative radius, the whole sphere is behind the plane, so it cannot be seen and the function returns false. If no plane rejects the sphere, the function returns true. This is the function we use for checks in this tutorial.

bool FrustumClass::CheckSphere(float xCenter, float yCenter, float zCenter, float radius)
{
    int i;


    // Check that the sphere is not entirely behind any plane of the view frustum.
    for(i=0; i<6; i++)
    {
        if(((m_planes[i].x * xCenter) + (m_planes[i].y * yCenter) + (m_planes[i].z * zCenter) + m_planes[i].w) < -radius)
        {
            return false;
        }
    }

    return true;
}
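
Comparing the signed distance against the radius only works because the planes were normalized in ConstructFrustum, so the distance is in world units. To make the three possible cases against a single plane explicit, here is a small sketch. The Plane struct, the SphereSide enum, and the ClassifySphere function are hypothetical and exist only for this illustration; CheckSphere itself only needs to detect the first case.

```cpp
// Hypothetical stand-in for XMFLOAT4: (x, y, z) is the unit plane normal, w the offset.
struct Plane { float x, y, z, w; };

enum class SphereSide { Outside, Intersecting, Inside };

// Classify a sphere against one normalized plane using the signed distance of its center.
SphereSide ClassifySphere(const Plane& plane, float xCenter, float yCenter, float zCenter, float radius)
{
    float distance = (plane.x * xCenter) + (plane.y * yCenter) + (plane.z * zCenter) + plane.w;

    // The center is more than one radius behind the plane: nothing pokes through.
    if(distance < -radius)
    {
        return SphereSide::Outside;
    }

    // The center is within one radius of the plane: the sphere straddles it.
    if(distance < radius)
    {
        return SphereSide::Intersecting;
    }

    // The center is at least one radius in front of the plane.
    return SphereSide::Inside;
}
```

A sphere is culled as soon as any one of the six planes reports the outside case. Spheres that intersect a plane still have to be drawn.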

CheckRectangle works the same as CheckCube except that it takes as input the x radius, y radius, and z radius of the rectangle instead of a single radius. It calculates the eight corner points of the rectangle from these and performs the same per-plane checks as the CheckCube function.

bool FrustumClass::CheckRectangle(float xCenter, float yCenter, float zCenter, float xSize, float ySize, float zSize)
{
    int i;


    // For each plane, look for at least one corner of the rectangle on the inner side. If none is found the rectangle is culled.
    for(i=0; i<6; i++)
    {
        if(m_planes[i].x * (xCenter - xSize) +
           m_planes[i].y * (yCenter - ySize) +
           m_planes[i].z * (zCenter - zSize) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter + xSize) +
           m_planes[i].y * (yCenter - ySize) +
           m_planes[i].z * (zCenter - zSize) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter - xSize) +
           m_planes[i].y * (yCenter + ySize) +
           m_planes[i].z * (zCenter - zSize) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter - xSize) +
           m_planes[i].y * (yCenter - ySize) +
           m_planes[i].z * (zCenter + zSize) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter + xSize) +
           m_planes[i].y * (yCenter + ySize) +
           m_planes[i].z * (zCenter - zSize) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter + xSize) +
           m_planes[i].y * (yCenter - ySize) +
           m_planes[i].z * (zCenter + zSize) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter - xSize) +
           m_planes[i].y * (yCenter + ySize) +
           m_planes[i].z * (zCenter + zSize) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        if(m_planes[i].x * (xCenter + xSize) +
           m_planes[i].y * (yCenter + ySize) +
           m_planes[i].z * (zCenter + zSize) + m_planes[i].w >= 0.0f)
        {
            continue;
        }

        return false;
    }

    return true;
}
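
As an aside, the eight corner tests per plane can be reduced to one by testing only the corner that lies farthest along the plane normal (sometimes called the positive vertex). This is a different technique from the one used in CheckRectangle above, although it is intended to give the same answer when the sizes are not negative. The Plane struct and the BoxTouchesFrustum function in this sketch are hypothetical names.

```cpp
#include <cmath>

// Hypothetical stand-in for XMFLOAT4: (x, y, z) is the plane normal, w the offset.
struct Plane { float x, y, z, w; };

// Returns false only when the whole box lies behind a single plane.
// xSize, ySize, and zSize are half-extents and are assumed to be non-negative.
bool BoxTouchesFrustum(const Plane planes[6], float xCenter, float yCenter, float zCenter,
                       float xSize, float ySize, float zSize)
{
    for(int i = 0; i < 6; i++)
    {
        // Plane equation evaluated at the center of the box.
        float centerDistance = (planes[i].x * xCenter) + (planes[i].y * yCenter) +
                               (planes[i].z * zCenter) + planes[i].w;

        // How much farther the best corner reaches in the direction of the normal.
        float reach = (std::fabs(planes[i].x) * xSize) + (std::fabs(planes[i].y) * ySize) +
                      (std::fabs(planes[i].z) * zSize);

        // Even the farthest corner is behind this plane, so the box is culled.
        if(centerDistance + reach < 0.0f)
        {
            return false;
        }
    }

    return true;
}
```

The sum of the center value and the reach is the largest value the plane equation takes over the eight corners, so if it is negative then every corner is behind the plane.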

Modellistclass.h

ModelListClass is a new class for maintaining information about all the models in the scene. For this tutorial it only maintains the position of the sphere models since we only have one model type. This class can be expanded to maintain all the different types of models in the scene and indexes to their ModelClass but I am keeping this tutorial simple for now.

////////////////////////////////////////////////////////////////////////////////
// Filename: modellistclass.h
////////////////////////////////////////////////////////////////////////////////
#ifndef _MODELLISTCLASS_H_
#define _MODELLISTCLASS_H_


//////////////
// INCLUDES //
//////////////
#include <stdlib.h>
#include <time.h>


///////////////////////////////////////////////////////////////////////////////
// Class name: ModelListClass
///////////////////////////////////////////////////////////////////////////////
class ModelListClass
{
private:
    struct ModelInfoType
    {
        float positionX, positionY, positionZ;
    };

public:
    ModelListClass();
    ModelListClass(const ModelListClass&);
    ~ModelListClass();

    void Initialize(int);
    void Shutdown();

    int GetModelCount();
    void GetData(int, float&, float&, float&);

private:
    int m_modelCount;
    ModelInfoType* m_ModelInfoList;
};

#endif

Modellistclass.cpp

////////////////////////////////////////////////////////////////////////////////
// Filename: modellistclass.cpp
////////////////////////////////////////////////////////////////////////////////
#include "modellistclass.h"


ModelListClass::ModelListClass()
{
    m_modelCount = 0;
    m_ModelInfoList = 0;
}


ModelListClass::ModelListClass(const ModelListClass& other)
{
}


ModelListClass::~ModelListClass()
{
}


void ModelListClass::Initialize(int numModels)
{
    int i;

First store the number of models that will be used and then create the list array of them using the ModelInfoType structure.

    // Store the number of models.
    m_modelCount = numModels;

    // Create a list array of the model information.
    m_ModelInfoList = new ModelInfoType[m_modelCount];

Seed the random number generator with the current time and then randomly generate the position of the models and store them in the list array.

    // Seed the random generator with the current time.
    srand((unsigned int)time(NULL));

    // Go through all the models and randomly generate the position.
    for(i=0; i<m_modelCount; i++)
    {
        // Generate a random position in front of the viewer for the model.
        m_ModelInfoList[i].positionX = (((float)rand() - (float)rand()) / RAND_MAX) * 10.0f;
        m_ModelInfoList[i].positionY = (((float)rand() - (float)rand()) / RAND_MAX) * 10.0f;
        m_ModelInfoList[i].positionZ = ((((float)rand() - (float)rand()) / RAND_MAX) * 10.0f) + 5.0f;
    }

    return;
}
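
Subtracting two rand() results and dividing by RAND_MAX gives a value between -1 and 1 that tends to cluster around zero, so the spheres end up denser near the middle of the range. If an even spread is preferred, the C++ random library can be used instead. The sketch below is one possible alternative and not the tutorial's method; the ModelPosition struct and the GeneratePositions function are hypothetical names.

```cpp
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical equivalent of ModelInfoType for this example.
struct ModelPosition { float x, y, z; };

// Generate count positions spread evenly over the same ranges as Initialize:
// x and y from -10 to 10, and z from -5 to 15.
std::vector<ModelPosition> GeneratePositions(std::size_t count, unsigned int seed)
{
    std::mt19937 engine(seed);
    std::uniform_real_distribution<float> offset(-10.0f, 10.0f);

    std::vector<ModelPosition> positions;
    positions.reserve(count);

    for(std::size_t i = 0; i < count; i++)
    {
        ModelPosition p;
        p.x = offset(engine);
        p.y = offset(engine);
        p.z = offset(engine) + 5.0f;
        positions.push_back(p);
    }

    return positions;
}
```

Passing a fixed seed reproduces the same layout on every run, which can be handy when debugging culling; passing a time-based seed gives a new layout each time, like the srand call above.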

The Shutdown function releases the model information list array.

void ModelListClass::Shutdown()
{
    // Release the model information list.
    if(m_ModelInfoList)
    {
        delete [] m_ModelInfoList;
        m_ModelInfoList = 0;
    }

    return;
}

GetModelCount returns the number of models that this class maintains information about.

int ModelListClass::GetModelCount()
{
    return m_modelCount;
}

The GetData function extracts the position of a model at the given input index location.

void ModelListClass::GetData(int index, float& positionX, float& positionY, float& positionZ)
{
    positionX = m_ModelInfoList[index].positionX;
    positionY = m_ModelInfoList[index].positionY;
    positionZ = m_ModelInfoList[index].positionZ;
    return;
}
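
GetData trusts the caller to pass a valid index. If that is a concern, a bounds-checked variant could look something like the sketch below. CheckedModelList and ModelInfo are hypothetical names for this example; the sketch also stores the list in a std::vector, which is a different design from the raw array used in ModelListClass.

```cpp
#include <utility>
#include <vector>

// Hypothetical equivalent of ModelInfoType for this example.
struct ModelInfo { float positionX, positionY, positionZ; };

// Hypothetical variant of ModelListClass that refuses out-of-range indexes.
class CheckedModelList
{
public:
    explicit CheckedModelList(std::vector<ModelInfo> models) : m_models(std::move(models)) {}

    int GetModelCount() const
    {
        return static_cast<int>(m_models.size());
    }

    // Returns false and leaves the outputs untouched if the index is out of range.
    bool GetData(int index, float& positionX, float& positionY, float& positionZ) const
    {
        if((index < 0) || (index >= GetModelCount()))
        {
            return false;
        }

        positionX = m_models[index].positionX;
        positionY = m_models[index].positionY;
        positionZ = m_models[index].positionZ;
        return true;
    }

private:
    std::vector<ModelInfo> m_models;
};
```

Because the vector releases its own memory, a class written this way would not need a separate Shutdown function.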

Positionclass.h

To allow camera movement with the left and right arrow keys in this tutorial, we create a new class to calculate and maintain the position of the viewer. This class will only handle turning left and right for now but can be expanded to maintain all different movement changes. The movement also includes acceleration and deceleration to create a smooth camera effect.

////////////////////////////////////////////////////////////////////////////////
// Filename: positionclass.h
////////////////////////////////////////////////////////////////////////////////
#ifndef _POSITIONCLASS_H_
#define _POSITIONCLASS_H_


//////////////
// INCLUDES //
//////////////
#include <math.h>


////////////////////////////////////////////////////////////////////////////////
// Class name: PositionClass
////////////////////////////////////////////////////////////////////////////////
class PositionClass
{
public:
    PositionClass();
    PositionClass(const PositionClass&);
    ~PositionClass();

    void SetFrameTime(float);
    void GetRotation(float&);

    void TurnLeft(bool);
    void TurnRight(bool);

private:
    float m_frameTime;
    float m_rotationY;
    float m_leftTurnSpeed, m_rightTurnSpeed;
};

#endif

Positionclass.cpp

////////////////////////////////////////////////////////////////////////////////
// Filename: positionclass.cpp
////////////////////////////////////////////////////////////////////////////////
#include "positionclass.h"

The class constructor initializes the private member variables to zero to start with.

PositionClass::PositionClass()
{
    m_frameTime = 0.0f;
    m_rotationY = 0.0f;
    m_leftTurnSpeed  = 0.0f;
    m_rightTurnSpeed = 0.0f;
}


PositionClass::PositionClass(const PositionClass& other)
{
}


PositionClass::~PositionClass()
{
}

The SetFrameTime function stores the current frame time in this class. PositionClass uses the frame time to calculate how fast the viewer should be moving and rotating. This function should always be called at the beginning of each frame before using this class to change the viewing position.

void PositionClass::SetFrameTime(float time)
{
    m_frameTime = time;
    return;
}

GetRotation returns the Y-axis rotation of the viewer. This is the only helper function we need for this tutorial but could be expanded to get more information about the location of the viewer.

void PositionClass::GetRotation(float& y)
{
    y = m_rotationY;
    return;
}

The movement functions both work the same. Both functions are called each frame. The keydown input variable to each function indicates if the user is pressing the left key or the right key. If they are pressing the key then each frame the speed will accelerate until it hits a maximum. This way the camera speeds up similar to the acceleration in a vehicle, creating the effect of smooth movement and high responsiveness. Likewise, if the user releases the key and the keydown variable is false, it will smoothly slow down each frame until the speed hits zero. The speed is calculated against the frame time to ensure the movement speed remains the same regardless of the frame rate. Each function then applies the turn speed to the Y rotation of the camera and wraps the result so that it stays between 0 and 360 degrees.

void PositionClass::TurnLeft(bool keydown)
{
    // If the key is pressed increase the speed at which the camera turns left.  If not slow down the turn speed.
    if(keydown)
    {
        m_leftTurnSpeed += m_frameTime * 1.5f;

        if(m_leftTurnSpeed > (m_frameTime * 200.0f))
        {
            m_leftTurnSpeed = m_frameTime * 200.0f;
        }
    }
    else
    {
        m_leftTurnSpeed -= m_frameTime * 1.0f;

        if(m_leftTurnSpeed < 0.0f)
        {
            m_leftTurnSpeed = 0.0f;
        }
    }

    // Update the rotation using the turning speed.
    m_rotationY -= m_leftTurnSpeed;
    if(m_rotationY < 0.0f)
    {
        m_rotationY += 360.0f;
    }

    return;
}


void PositionClass::TurnRight(bool keydown)
{
    // If the key is pressed increase the speed at which the camera turns right.  If not slow down the turn speed.
    if(keydown)
    {
        m_rightTurnSpeed += m_frameTime * 1.5f;

        if(m_rightTurnSpeed > (m_frameTime * 200.0f))
        {
            m_rightTurnSpeed = m_frameTime * 200.0f;
        }
    }
    else
    {
        m_rightTurnSpeed -= m_frameTime * 1.0f;

        if(m_rightTurnSpeed < 0.0f)
        {
            m_rightTurnSpeed = 0.0f;
        }
    }

    // Update the rotation using the turning speed.
    m_rotationY += m_rightTurnSpeed;
    if(m_rotationY > 360.0f)
    {
        m_rotationY -= 360.0f;
    }

    return;
}
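
To get a feel for the acceleration and deceleration, the right-turn update can be restated as a free function and stepped in a loop outside of the engine. The TurnState struct and the StepTurnRight function below are hypothetical names for this sketch, and the frame time passed in is whatever value the caller chooses; it is not taken from the TimerClass.

```cpp
// Hypothetical state for one turn direction; mirrors m_rightTurnSpeed and m_rotationY.
struct TurnState
{
    float speed = 0.0f;
    float rotationY = 0.0f;
};

// One frame of the right-turn update, using the same constants as TurnRight.
void StepTurnRight(TurnState& state, bool keydown, float frameTime)
{
    if(keydown)
    {
        // Accelerate, but never beyond the per-frame maximum.
        state.speed += frameTime * 1.5f;
        if(state.speed > (frameTime * 200.0f))
        {
            state.speed = frameTime * 200.0f;
        }
    }
    else
    {
        // Decelerate until the turn stops.
        state.speed -= frameTime * 1.0f;
        if(state.speed < 0.0f)
        {
            state.speed = 0.0f;
        }
    }

    // Apply the turn and wrap the rotation back into range.
    state.rotationY += state.speed;
    if(state.rotationY > 360.0f)
    {
        state.rotationY -= 360.0f;
    }
}
```

Calling StepTurnRight repeatedly with keydown set to true raises the speed by frameTime * 1.5 each frame until it is clamped at frameTime * 200. Calling it with keydown set to false then lowers the speed by frameTime each frame until it reaches zero.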

Inputclass.h

The InputClass was modified for this tutorial to add left and right key presses.

////////////////////////////////////////////////////////////////////////////////
// Filename: inputclass.h
////////////////////////////////////////////////////////////////////////////////
#ifndef _INPUTCLASS_H_
#define _INPUTCLASS_H_


///////////////////////////////
// PRE-PROCESSING DIRECTIVES //
///////////////////////////////
#define DIRECTINPUT_VERSION 0x0800


/////////////
// LINKING //
/////////////
#pragma comment(lib, "dinput8.lib")
#pragma comment(lib, "dxguid.lib")


//////////////
// INCLUDES //
//////////////
#include <dinput.h>


////////////////////////////////////////////////////////////////////////////////
// Class name: InputClass
////////////////////////////////////////////////////////////////////////////////
class InputClass
{
public:
    InputClass();
    InputClass(const InputClass&);
    ~InputClass();

    bool Initialize(HINSTANCE, HWND, int, int);
    void Shutdown();
    bool Frame();

    bool IsEscapePressed();
    bool IsLeftArrowPressed();
    bool IsRightArrowPressed();

    void GetMouseLocation(int&, int&);
    bool IsMousePressed();

private:
    bool ReadKeyboard();
    bool ReadMouse();
    void ProcessInput();

private:
    IDirectInput8* m_directInput;
    IDirectInputDevice8* m_keyboard;
    IDirectInputDevice8* m_mouse;
    unsigned char m_keyboardState[256];
    DIMOUSESTATE m_mouseState;
    int m_screenWidth, m_screenHeight, m_mouseX, m_mouseY;
};

#endif

Inputclass.cpp

////////////////////////////////////////////////////////////////////////////////
// Filename: inputclass.cpp
////////////////////////////////////////////////////////////////////////////////
#include "inputclass.h"


InputClass::InputClass()
{
    m_directInput = 0;
    m_keyboard = 0;
    m_mouse = 0;
}


InputClass::InputClass(const InputClass& other)
{
}


InputClass::~InputClass()
{
}


bool InputClass::Initialize(HINSTANCE hinstance, HWND hwnd, int screenWidth, int screenHeight)
{
    HRESULT result;


    // Store the screen size which will be used for positioning the mouse cursor.
    m_screenWidth = screenWidth;
    m_screenHeight = screenHeight;

    // Initialize the location of the mouse on the screen.
    m_mouseX = 0;
    m_mouseY = 0;

    // Initialize the main direct input interface.
    result = DirectInput8Create(hinstance, DIRECTINPUT_VERSION, IID_IDirectInput8, (void**)&m_directInput, NULL);
    if(FAILED(result))
    {
        return false;
    }

    // Initialize the direct input interface for the keyboard.
    result = m_directInput->CreateDevice(GUID_SysKeyboard, &m_keyboard, NULL);
    if(FAILED(result))
    {
        return false;
    }

    // Set the data format.  In this case since it is a keyboard we can use the predefined data format.
    result = m_keyboard->SetDataFormat(&c_dfDIKeyboard);
    if(FAILED(result))
    {
        return false;
    }

    // Set the cooperative level of the keyboard to not share with other programs.
    result = m_keyboard->SetCooperativeLevel(hwnd, DISCL_FOREGROUND | DISCL_EXCLUSIVE);
    if(FAILED(result))
    {
        return false;
    }

    // Now acquire the keyboard.
    result = m_keyboard->Acquire();
    if(FAILED(result))
    {
        return false;
    }

    // Initialize the direct input interface for the mouse.
    result = m_directInput->CreateDevice(GUID_SysMouse, &m_mouse, NULL);
    if(FAILED(result))
    {
        return false;
    }

    // Set the data format for the mouse using the pre-defined mouse data format.
    result = m_mouse->SetDataFormat(&c_dfDIMouse);
    if(FAILED(result))
    {
        return false;
    }

    // Set the cooperative level of the mouse to share with other programs.
    result = m_mouse->SetCooperativeLevel(hwnd, DISCL_FOREGROUND | DISCL_NONEXCLUSIVE);
    if(FAILED(result))
    {
        return false;
    }

    // Acquire the mouse.
    result = m_mouse->Acquire();
    if(FAILED(result))
    {
        return false;
    }

    return true;
}


void InputClass::Shutdown()
{
    // Release the mouse.
    if(m_mouse)
    {
        m_mouse->Unacquire();
        m_mouse->Release();
        m_mouse = 0;
    }

    // Release the keyboard.
    if(m_keyboard)
    {
        m_keyboard->Unacquire();
        m_keyboard->Release();
        m_keyboard = 0;
    }

    // Release the main interface to direct input.
    if(m_directInput)
    {
        m_directInput->Release();
        m_directInput = 0;
    }

    return;
}


bool InputClass::Frame()
{
    bool result;


    // Read the current state of the keyboard.
    result = ReadKeyboard();
    if(!result)
    {
        return false;
    }

    // Read the current state of the mouse.
    result = ReadMouse();
    if(!result)
    {
        return false;
    }

    // Process the changes in the mouse and keyboard.
    ProcessInput();

    return true;
}


bool InputClass::ReadKeyboard()
{
    HRESULT result;


    // Read the keyboard device.
    result = m_keyboard->GetDeviceState(sizeof(m_keyboardState), (LPVOID)&m_keyboardState);
    if(FAILED(result))
    {
        // If the keyboard lost focus or was not acquired then try to get control back.
        if((result == DIERR_INPUTLOST) || (result == DIERR_NOTACQUIRED))
        {
            m_keyboard->Acquire();
        }
        else
        {
            return false;
        }
    }
		
    return true;
}


bool InputClass::ReadMouse()
{
    HRESULT result;


    // Read the mouse device.
    result = m_mouse->GetDeviceState(sizeof(DIMOUSESTATE), (LPVOID)&m_mouseState);
    if(FAILED(result))
    {
        // If the mouse lost focus or was not acquired then try to get control back.
        if((result == DIERR_INPUTLOST) || (result == DIERR_NOTACQUIRED))
        {
            m_mouse->Acquire();
        }
        else
        {
            return false;
        }
    }

    return true;
}


void InputClass::ProcessInput()
{
    // Update the location of the mouse cursor based on the change of the mouse location during the frame.
    m_mouseX += m_mouseState.lX;
    m_mouseY += m_mouseState.lY;

    // Ensure the mouse location doesn't exceed the screen width or height.
    if(m_mouseX < 0)  { m_mouseX = 0; }
    if(m_mouseY < 0)  { m_mouseY = 0; }
	
    if(m_mouseX > m_screenWidth)  { m_mouseX = m_screenWidth; }
    if(m_mouseY > m_screenHeight) { m_mouseY = m_screenHeight; }
	
    return;
}


bool InputClass::IsEscapePressed()
{
    // Do a bitwise and on the keyboard state to check if the escape key is currently being pressed.
    if(m_keyboardState[DIK_ESCAPE] & 0x80)
    {
        return true;
    }

    return false;
}

We add two new functions for checking if the left or right arrow keys are pressed.

bool InputClass::IsLeftArrowPressed()
{
    if(m_keyboardState[DIK_LEFT] & 0x80)
    {
        return true;
    }

    return false;
}


bool InputClass::IsRightArrowPressed()
{
    if(m_keyboardState[DIK_RIGHT] & 0x80)
    {
        return true;
    }

    return false;
}
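
All three key functions test the same thing: DirectInput reports each key as one byte in the keyboard state array, and the high bit (0x80) of that byte is set while the key is held down. A minimal sketch of a shared helper is shown below. IsKeyDown and the two constants are hypothetical names for this example; the constant values are meant to match DIK_LEFT and DIK_RIGHT from dinput.h so that the sketch builds without that header.

```cpp
// Assumed to match DIK_LEFT and DIK_RIGHT in dinput.h; defined here so the
// sketch does not need the DirectInput header.
const unsigned char kLeftArrow = 0xCB;
const unsigned char kRightArrow = 0xCD;

// Hypothetical helper: true if the high bit is set for the given key in a
// 256-byte keyboard state array such as m_keyboardState.
bool IsKeyDown(const unsigned char* keyboardState, unsigned char key)
{
    return (keyboardState[key] & 0x80) != 0;
}
```

With a helper like this, IsLeftArrowPressed would reduce to a single call that passes m_keyboardState and DIK_LEFT.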


void InputClass::GetMouseLocation(int& mouseX, int& mouseY)
{
    mouseX = m_mouseX;
    mouseY = m_mouseY;
    return;
}


bool InputClass::IsMousePressed()
{
    // Check the left mouse button state.
    if(m_mouseState.rgbButtons[0] & 0x80)
    {
        return true;
    }

    return false;
}

Applicationclass.h

The ApplicationClass for this tutorial includes a number of classes we have used in the previous tutorials. It also includes the frustumclass.h, positionclass.h, and modellistclass.h headers which are new.

////////////////////////////////////////////////////////////////////////////////
// Filename: applicationclass.h
////////////////////////////////////////////////////////////////////////////////
#ifndef _APPLICATIONCLASS_H_
#define _APPLICATIONCLASS_H_


///////////////////////
// MY CLASS INCLUDES //
///////////////////////
#include "d3dclass.h"
#include "inputclass.h"
#include "cameraclass.h"
#include "modelclass.h"
#include "lightclass.h"
#include "lightshaderclass.h"
#include "fontshaderclass.h"
#include "fontclass.h"
#include "textclass.h"
#include "modellistclass.h"
#include "timerclass.h"
#include "positionclass.h"
#include "frustumclass.h"


/////////////
// GLOBALS //
/////////////
const bool FULL_SCREEN = false;
const bool VSYNC_ENABLED = true;
const float SCREEN_DEPTH = 1000.0f;
const float SCREEN_NEAR = 0.3f;


////////////////////////////////////////////////////////////////////////////////
// Class name: ApplicationClass
////////////////////////////////////////////////////////////////////////////////
class ApplicationClass
{
public:
    ApplicationClass();
    ApplicationClass(const ApplicationClass&);
    ~ApplicationClass();

    bool Initialize(int, int, HWND);
    void Shutdown();
    bool Frame(InputClass*);

private:
    bool Render();
    bool UpdateRenderCountString(int);

private:
    D3DClass* m_Direct3D;
    CameraClass* m_Camera;
    ModelClass* m_Model;
    LightClass* m_Light;
    LightShaderClass* m_LightShader;
    FontShaderClass* m_FontShader;
    FontClass* m_Font;
    TextClass* m_RenderCountString;
    ModelListClass* m_ModelList;
    TimerClass* m_Timer;
    PositionClass* m_Position;
    FrustumClass* m_Frustum;
    XMMATRIX m_baseViewMatrix;
};

#endif

Applicationclass.cpp

////////////////////////////////////////////////////////////////////////////////
// Filename: applicationclass.cpp
////////////////////////////////////////////////////////////////////////////////
#include "applicationclass.h"


ApplicationClass::ApplicationClass()
{
    m_Direct3D = 0;
    m_Camera = 0;
    m_Model = 0;
    m_Light = 0;
    m_LightShader = 0;
    m_FontShader = 0;
    m_Font = 0;
    m_RenderCountString = 0;
    m_ModelList = 0;
    m_Timer = 0;
    m_Position = 0;
    m_Frustum = 0;
}


ApplicationClass::ApplicationClass(const ApplicationClass& other)
{
}


ApplicationClass::~ApplicationClass()
{
}


bool ApplicationClass::Initialize(int screenWidth, int screenHeight, HWND hwnd)
{
    char modelFilename[128], textureFilename1[128], renderString[32];
    bool result;


    // Create and initialize the Direct3D object.
    m_Direct3D = new D3DClass;

    result = m_Direct3D->Initialize(screenWidth, screenHeight, VSYNC_ENABLED, hwnd, FULL_SCREEN, SCREEN_DEPTH, SCREEN_NEAR);
    if(!result)
    {
        MessageBox(hwnd, L"Could not initialize Direct3D", L"Error", MB_OK);
        return false;
    }

We create the camera object as we normally do. However, we also store a copy of the view matrix at startup. The camera's view matrix changes every time the camera turns, and we need this unmodified copy to render the text at the same screen location each frame.

    // Create and initialize the camera object.
    m_Camera = new CameraClass;

    m_Camera->SetPosition(0.0f, 0.0f, -10.0f);
    m_Camera->Render();
    m_Camera->GetViewMatrix(m_baseViewMatrix);

We will load a sphere model for this tutorial.

    // Set the file name of the model.
    strcpy_s(modelFilename, "../Engine/data/sphere.txt");

    // Set the file name of the textures.
    strcpy_s(textureFilename1, "../Engine/data/stone01.tga");

    // Create and initialize the model object.
    m_Model = new ModelClass;

    result = m_Model->Initialize(m_Direct3D->GetDevice(), m_Direct3D->GetDeviceContext(), modelFilename, textureFilename1);
    if(!result)
    {
        return false;
    }

We will use just the basic light shader to render the spheres.

    // Create and initialize the light object.
    m_Light = new LightClass;

    m_Light->SetDiffuseColor(1.0f, 1.0f, 1.0f, 1.0f);
    m_Light->SetDirection(0.0f, 0.0f, 1.0f);

    // Create and initialize the light shader object.
    m_LightShader = new LightShaderClass;

    result = m_LightShader->Initialize(m_Direct3D->GetDevice(), hwnd);
    if(!result)
    {
        return false;
    }

We display the number of spheres being rendered on the screen using FontClass, FontShaderClass, and TextClass objects.

    // Create and initialize the font shader object.
    m_FontShader = new FontShaderClass;

    result = m_FontShader->Initialize(m_Direct3D->GetDevice(), hwnd);
    if(!result)
    {
        MessageBox(hwnd, L"Could not initialize the font shader object.", L"Error", MB_OK);
        return false;
    }

    // Create and initialize the font object.
    m_Font = new FontClass;

    result = m_Font->Initialize(m_Direct3D->GetDevice(), m_Direct3D->GetDeviceContext(), 0);
    if(!result)
    {
        return false;
    }

    // Set the initial render count string.
    strcpy_s(renderString, "Render Count: 0");

    // Create and initialize the text object for the render count string.
    m_RenderCountString = new TextClass;

    result = m_RenderCountString->Initialize(m_Direct3D->GetDevice(), m_Direct3D->GetDeviceContext(), screenWidth, screenHeight, 32, m_Font, renderString, 10, 10, 1.0f, 1.0f, 1.0f);
    if(!result)
    {
        return false;
    }

Here we create the new ModelListClass object and have it create 25 randomly placed sphere models.

    // Create and initialize the model list object.
    m_ModelList = new ModelListClass;
    m_ModelList->Initialize(25);

We also create a timer, since the camera rotation is scaled by the frame time to keep the movement smooth.

    // Create and initialize the timer object.
    m_Timer = new TimerClass;

    result = m_Timer->Initialize();
    if(!result)
    {
        return false;
    }

Here we create the new PositionClass object, which keeps track of the direction the viewer is looking in.

    // Create the position class object.
    m_Position = new PositionClass;

And finally, we create our new FrustumClass object here.

    // Create the frustum class object.
    m_Frustum = new FrustumClass;

    return true;
}
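
The ModelListClass::Initialize(25) call above only says that the spheres are placed randomly. As a rough illustration of what that can look like, here is a minimal sketch. It is not the tutorial's ModelListClass: the SpherePosition struct, the makeRandomPositions function, and the coordinate ranges are all assumptions made for this example.

```cpp
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical stand-in for one entry of the model list: a position only.
struct SpherePosition
{
    float x, y, z;
};

// Generate `count` random positions from a seeded generator.
// The ranges are illustrative assumptions: x and y between -10 and 10,
// and z between 5 and 15, which is in front of a camera placed at z = -10.
std::vector<SpherePosition> makeRandomPositions(std::size_t count, unsigned int seed)
{
    std::mt19937 generator(seed);
    std::uniform_real_distribution<float> sideways(-10.0f, 10.0f);
    std::uniform_real_distribution<float> depth(5.0f, 15.0f);

    std::vector<SpherePosition> positions;
    positions.reserve(count);

    for(std::size_t i = 0; i < count; i++)
    {
        SpherePosition p;
        p.x = sideways(generator);
        p.y = sideways(generator);
        p.z = depth(generator);
        positions.push_back(p);
    }

    return positions;
}
```

Calling makeRandomPositions(25, seed) with a fixed seed gives a repeatable layout, which can be handy when comparing render counts between runs. Seeding from the clock instead would give a different layout each time the program starts.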


void ApplicationClass::Shutdown()
{
    // Release the frustum class object.
    if(m_Frustum)
    {
        delete m_Frustum;
        m_Frustum = 0;
    }

    // Release the position object.
    if(m_Position)
    {
        delete m_Position;
        m_Position = 0;
    }

    // Release the timer object.
    if(m_Timer)
    {
        delete m_Timer;
        m_Timer = 0;
    }

    // Release the model list object.
    if(m_ModelList)
    {
        m_ModelList->Shutdown();
        delete m_ModelList;
        m_ModelList = 0;
    }

    // Release the text object for the render count string.
    if(m_RenderCountString)
    {
        m_RenderCountString->Shutdown();
        delete m_RenderCountString;
        m_RenderCountString = 0;
    }

    // Release the font object.
    if(m_Font)
    {
        m_Font->Shutdown();
        delete m_Font;
        m_Font = 0;
    }

    // Release the font shader object.
    if(m_FontShader)
    {
        m_FontShader->Shutdown();
        delete m_FontShader;
        m_FontShader = 0;
    }

    // Release the light shader object.
    if(m_LightShader)
    {
        m_LightShader->Shutdown();
        delete m_LightShader;
        m_LightShader = 0;
    }

    // Release the light object.
    if(m_Light)
    {
        delete m_Light;
        m_Light = 0;
    }

    // Release the model object.
    if(m_Model)
    {
        m_Model->Shutdown();
        delete m_Model;
        m_Model = 0;
    }

    // Release the camera object.
    if(m_Camera)
    {
        delete m_Camera;
        m_Camera = 0;
    }

    // Release the Direct3D object.
    if(m_Direct3D)
    {
        m_Direct3D->Shutdown();
        delete m_Direct3D;
        m_Direct3D = 0;
    }

    return;
}


bool ApplicationClass::Frame(InputClass* Input)
{
    float rotationY;
    bool result, keyDown;

Each frame we update the timer and then pass the frame time to the PositionClass object so that it can scale any movement by the current frame time.

    // Update the timer.
    m_Timer->Frame();

    // Check if the user pressed escape and wants to exit the application.
    if(Input->IsEscapePressed())
    {
        return false;
    }

    // Set the frame time for calculating the updated position.
    m_Position->SetFrameTime(m_Timer->GetTime());

Here we check whether the left or right arrow key is pressed. We pass the result to the PositionClass object so it can accelerate or decelerate the turn in that direction.

    // Check if the left or right arrow key is pressed and rotate the camera accordingly.
    keyDown = Input->IsLeftArrowPressed();
    m_Position->TurnLeft(keyDown);

    keyDown = Input->IsRightArrowPressed();
    m_Position->TurnRight(keyDown);

Each frame we get the current rotation from the PositionClass object and then update the CameraClass object with it. This way we can provide our shaders with the updated view matrix each frame.

    // Get the current view point rotation.
    m_Position->GetRotation(rotationY);

    // Set the rotation of the camera.
    m_Camera->SetRotation(0.0f, rotationY, 0.0f);
    m_Camera->Render();

    // Render the graphics scene.
    result = Render();
    if(!result)
    {
        return false;
    }

    return true;
}
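
The Frame function relies on PositionClass to accelerate the turn while a key is held and to decelerate it once the key is released. The following is a minimal sketch of that idea for the right turn only. It is not the PositionClass source: the TurnState struct, the turnRight function, and the three tuning constants are assumptions made for this example, and the frame time is assumed to be in milliseconds.

```cpp
#include <algorithm>
#include <cmath>

// Hypothetical, simplified stand-in for the turning state; not PositionClass itself.
struct TurnState
{
    float rotationY = 0.0f;   // Current heading in degrees, kept in [0, 360).
    float turnSpeed = 0.0f;   // Degrees added to the heading per frame.
};

// Advance the right-turn state by one frame.
// The acceleration, deceleration, and cap factors are illustrative assumptions.
void turnRight(TurnState& state, bool keyDown, float frameTime)
{
    const float acceleration = 0.01f;
    const float deceleration = 0.005f;
    const float maxSpeedFactor = 0.15f;

    if(keyDown)
    {
        // Speed up while the key is held, but never past the cap.
        state.turnSpeed = std::min(state.turnSpeed + frameTime * acceleration, frameTime * maxSpeedFactor);
    }
    else
    {
        // Slow down once the key is released, but never below zero.
        state.turnSpeed = std::max(state.turnSpeed - frameTime * deceleration, 0.0f);
    }

    // Apply the speed to the heading and wrap it back into [0, 360).
    state.rotationY = std::fmod(state.rotationY + state.turnSpeed, 360.0f);
}
```

Because every term is multiplied by the frame time, a slow frame turns the camera further than a fast one, so the rotation speed stays roughly the same regardless of frame rate. A left turn would mirror this logic and subtract from the heading instead.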


bool ApplicationClass::Render()
{
    XMMATRIX worldMatrix, viewMatrix, projectionMatrix, orthoMatrix;
    float positionX, positionY, positionZ, radius;
    int modelCount, renderCount, i;
    bool renderModel, result;


    // Clear the buffers to begin the scene.
    m_Direct3D->BeginScene(0.0f, 0.0f, 0.0f, 1.0f);

    // Get the world, view, and projection matrices from the camera and d3d objects.
    m_Direct3D->GetWorldMatrix(worldMatrix);
    m_Camera->GetViewMatrix(viewMatrix);
    m_Direct3D->GetProjectionMatrix(projectionMatrix);
    m_Direct3D->GetOrthoMatrix(orthoMatrix);

The major change to the Render function is that we now construct the viewing frustum each frame from the updated view matrix. This has to be done every time the view matrix changes, or the frustum culling checks will not be correct.

    // Construct the frustum.
    m_Frustum->ConstructFrustum(viewMatrix, projectionMatrix, SCREEN_DEPTH);

    // Get the number of models that will be rendered.
    modelCount = m_ModelList->GetModelCount();

    // Initialize the count of models that have been rendered.
    renderCount = 0;

Now loop through all the models in the ModelListClass object.

    // Go through all the models and render them only if they can be seen by the camera view.
    for(i=0; i<modelCount; i++)
    {
        // Get the position of the sphere model at this index.
        m_ModelList->GetData(i, positionX, positionY, positionZ);

        // Set the radius of the sphere to 1.0 since this is already known.
        radius = 1.0f;

Here is where we use the new FrustumClass object. We check whether the sphere is visible in the viewing frustum. If it is, we render it; if it is not, we skip it and check the next one. This is where frustum culling gains its speed, since a culled sphere is never sent to the video card.

        // Check if the sphere model is in the view frustum.
        renderModel = m_Frustum->CheckSphere(positionX, positionY, positionZ, radius);

        // If it can be seen then render it, if not skip this model and check the next sphere.
        if(renderModel)
        {
            // Move the model to the location it should be rendered at.
            worldMatrix = XMMatrixTranslation(positionX, positionY, positionZ);

            // Render the model using the light shader.
            m_Model->Render(m_Direct3D->GetDeviceContext());

            result = m_LightShader->Render(m_Direct3D->GetDeviceContext(), m_Model->GetIndexCount(), worldMatrix, viewMatrix, projectionMatrix,
                                           m_Model->GetTexture(), m_Light->GetDirection(), m_Light->GetDiffuseColor());
            if(!result)
            {
                return false;
            }

            // Since this model was rendered, increase the count for this frame.
            renderCount++;
        }
    }

    // Update the render count text.
    result = UpdateRenderCountString(renderCount);
    if(!result)
    {
        return false;
    }

We use the UpdateRenderCountString function and the TextClass to display how many spheres were actually rendered. We can also infer from this number that the spheres that were not rendered were instead culled using the new FrustumClass object.

    // Disable the Z buffer and enable alpha blending for 2D rendering.
    m_Direct3D->TurnZBufferOff();
    m_Direct3D->EnableAlphaBlending();

    // Reset the world matrix.
    m_Direct3D->GetWorldMatrix(worldMatrix);

    // Render the render count text string using the font shader.
    m_RenderCountString->Render(m_Direct3D->GetDeviceContext());

    result = m_FontShader->Render(m_Direct3D->GetDeviceContext(), m_RenderCountString->GetIndexCount(), worldMatrix, m_baseViewMatrix, orthoMatrix,
                                  m_Font->GetTexture(), m_RenderCountString->GetPixelColor());
    if(!result)
    {
        return false;
    }

    // Enable the Z buffer and disable alpha blending now that 2D rendering is complete.
    m_Direct3D->TurnZBufferOn();
    m_Direct3D->DisableAlphaBlending();

    // Present the rendered scene to the screen.
    m_Direct3D->EndScene();

    return true;
}
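
The Render function treats CheckSphere as a yes or no question about a centre point and a radius. A common way to answer that question is to test the sphere against six inward-facing planes, and the following is a minimal sketch of that approach. It does not reproduce FrustumClass: the Plane struct and the functions signedDistance, sphereTouchesVolume, and makeBoxPlanes are names assumed for this example, and a simple box stands in for a real frustum so that the numbers are easy to check by hand.

```cpp
#include <array>

// A plane in the form a*x + b*y + c*z + d = 0, where (a, b, c) has unit length
// and points toward the inside of the volume.
struct Plane
{
    float a, b, c, d;
};

// Signed distance from a point to the plane; positive means the inside half-space.
float signedDistance(const Plane& plane, float x, float y, float z)
{
    return plane.a * x + plane.b * y + plane.c * z + plane.d;
}

// Returns false only when the sphere lies completely behind at least one plane.
bool sphereTouchesVolume(const std::array<Plane, 6>& planes, float x, float y, float z, float radius)
{
    for(const Plane& plane : planes)
    {
        if(signedDistance(plane, x, y, z) < -radius)
        {
            return false;
        }
    }

    return true;
}

// Hypothetical helper for trying the test out: the six inward-facing planes
// of an axis-aligned box centred on the origin.
std::array<Plane, 6> makeBoxPlanes(float halfSize)
{
    return {{
        { 1.0f,  0.0f,  0.0f, halfSize},   // x >= -halfSize
        {-1.0f,  0.0f,  0.0f, halfSize},   // x <=  halfSize
        { 0.0f,  1.0f,  0.0f, halfSize},   // y >= -halfSize
        { 0.0f, -1.0f,  0.0f, halfSize},   // y <=  halfSize
        { 0.0f,  0.0f,  1.0f, halfSize},   // z >= -halfSize
        { 0.0f,  0.0f, -1.0f, halfSize}    // z <=  halfSize
    }};
}
```

With makeBoxPlanes(10.0f), a sphere of radius 1 centred at x = 10.5 still passes because it straddles a face, while one centred at x = 12 is rejected. The test is deliberately conservative: a sphere just outside a corner can pass all six planes, which only costs an unnecessary draw and never hides a visible object.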

The UpdateRenderCountString function will take the integer render count and modify the text string so that it can be rendered on the screen.

bool ApplicationClass::UpdateRenderCountString(int renderCount)
{
    char tempString[16], finalString[32];
    bool result;


    // Convert the render count integer to string format.
    sprintf_s(tempString, "%d", renderCount);

    // Setup the render count string.
    strcpy_s(finalString, "Render Count: ");
    strcat_s(finalString, tempString);

    // Update the sentence vertex buffer with the new string information.
    result = m_RenderCountString->UpdateText(m_Direct3D->GetDeviceContext(), m_Font, finalString, 10, 10, 1.0f, 1.0f, 1.0f);
    if(!result)
    {
        return false;
    }

    return true;
}

Summary

Now you have seen how to cull objects. The remaining decision is which bounding shape (a cube, a rectangle, a sphere, or a well-chosen point) works best for culling each of your different objects.


To Do Exercises

1. Recompile and run the program. Use the left and right arrow keys to rotate the camera and watch the render count in the upper left corner update.

2. Load the cube model instead and change the cull check to CheckCube.

3. Create some different models and test which of the culling checks works best for them.


Source Code

Source Code and Data Files: dx11win10tut23_src.zip
