A DECISION TREE IMPLEMENTATION IN JAVA

CONTENTS

1. Introduction
2. A Generic Binary Decision Tree Generator and Query System



1. INTRODUCTION

Decision trees are an important structure used in many branches of Computer Science (e.g. classification and artificial intelligence).

The tree comprises a set of nodes commencing with a single root node and terminating at a set of leaf nodes; in between are located body nodes. The root and each body node have connections (called arcs or edges) to at least two other body or leaf nodes. A tree where the root and body nodes have uniformly two arcs each is called a binary tree.

At its simplest, the leaf nodes in a decision tree represent a set of terminating "answers", and the root and body nodes represent "questions". The user arrives at an answer by providing responses to the questions. The response to a particular question dictates which arc should be followed to the next question (or to an answer, if a leaf node is arrived at). In the case of a binary decision tree the responses are typically Yes or No, each corresponding to one of the two available branches at each question node.

From an implementation viewpoint a binary tree is like a linked list but with an extra "link": each node refers to two successor nodes rather than one.
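
As a rough illustration of the two points above, the following minimal sketch contrasts a linked-list node (one link) with a binary decision tree node (two links), and follows a fixed sequence of Yes/No responses from the root to an answer. It is not part of the class presented in Table 1: the names NodeSketch, ListNode, TreeNode, walk and animalTree are hypothetical, and the sketch assumes that every question node has both branches and that enough responses are supplied.

```java
// Illustrative sketch only; all names here are hypothetical.
class NodeSketch {

    // A linked-list node has a single link to the next node.
    static class ListNode {
        String value;
        ListNode next;
    }

    // A binary decision tree node has the same shape but with two links.
    static class TreeNode {
        final String questOrAns;
        final TreeNode yesBranch;   // followed on a "Yes" response
        final TreeNode noBranch;    // followed on a "No" response

        // Question node: has both branches.
        TreeNode(String question, TreeNode yesBranch, TreeNode noBranch) {
            this.questOrAns = question;
            this.yesBranch = yesBranch;
            this.noBranch = noBranch;
        }

        // Answer (leaf) node: has no branches.
        TreeNode(String answer) {
            this(answer, null, null);
        }
    }

    // Consume one response per question (true = Yes, false = No)
    // until a leaf node is reached, then return its answer.
    static String walk(TreeNode root, boolean... responses) {
        TreeNode current = root;
        int next = 0;
        while (current.yesBranch != null && current.noBranch != null) {
            current = responses[next++] ? current.yesBranch : current.noBranch;
        }
        return current.questOrAns;
    }

    // A small tree with two levels of questions and four answers.
    static TreeNode animalTree() {
        TreeNode meatEaters = new TreeNode("Does animal have stripes?",
                new TreeNode("Animal is a Tiger"),
                new TreeNode("Animal is a Leopard"));
        TreeNode others = new TreeNode("Does animal have stripes?",
                new TreeNode("Animal is a Zebra"),
                new TreeNode("Animal is a Horse"));
        return new TreeNode("Does animal eat meat?", meatEaters, others);
    }

    public static void main(String[] args) {
        // Yes to the first question, No to the second.
        System.out.println(walk(animalTree(), true, false));
    }
}
```

Each response selects one of the two links, so the number of questions asked equals the depth of the answer that is reached.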




2. A GENERIC BINARY DECISION TREE GENERATOR AND QUERY SYSTEM

Decision trees are so common that it seems useful to write a Java program that builds and queries such trees. The class presented in Table 1 does this for binary decision trees. The class contains, as one of its members, another class (an inner class) which defines a node in a generic decision tree. This node class (BinTree) has four fields:

  1. nodeID: An identification number which serves no other purpose than to aid the tree construction process.
  2. questOrAns: the "question" in the case of a root or body node, or "answer" in the case of a leaf node.
  3. yesBranch: Another instance of the class BinTree describing the "Yes" branch (this would be set to null in the case of a leaf node).
  4. noBranch: Same as yesBranch field but describing the "No" branch.

The code has three groups of methods:

  1. Tree generation methods: A method to create the root node and then further methods to add body and leaf nodes. We add a body or leaf node by providing the nodeID number of the existing node we wish to add to. The existing tree is then searched until this node is found (in which case we add the new node), or there are no more nodes in the tree (in which case we report an error). Note that: (1) the tree is built commencing from the root node, and (2) it is up to the user to ensure that the tree is properly "balanced", i.e. that the root and every body node end up with both a "Yes" and a "No" child.
  2. Output methods: A set of methods to output the tree once it has been built. This is really for diagnostic purposes only. Note that each node is tagged in the output according to its position in the tree: the root is tagged 1, and each level down appends .1 for a "Yes" node or .2 for a "No" node.
  3. Query methods: A set of methods that query the tree, asking each question in turn until an answer is reached.

// DECISION TREE
// Frans Coenen
// Thursday 15 August 2002
// Department of Computer Science, University of Liverpool

import java.io.*;

class DecisionTree {

    /* ------------------------------- */
    /*                                 */
    /*              FIELDS             */
    /*                                 */
    /* ------------------------------- */

    /* NESTED CLASS */

    private class BinTree {
    	
	/* FIELDS */
	
	private int     nodeID;
    	private String  questOrAns = null;
    	private BinTree yesBranch  = null;
    	private BinTree noBranch   = null;
	
	/* CONSTRUCTOR */
	
	public BinTree(int newNodeID, String newQuestAns) {
	    nodeID     = newNodeID;
	    questOrAns = newQuestAns;
            }
	}

    /* OTHER FIELDS */

    static BufferedReader    keyboardInput = new
                           BufferedReader(new InputStreamReader(System.in));
    BinTree rootNode = null;

    /* ------------------------------------ */
    /*                                      */
    /*              CONSTRUCTORS            */
    /*                                      */
    /* ------------------------------------ */

    /* Default Constructor */

    public DecisionTree() {
	}

    /* ----------------------------------------------- */
    /*                                                 */
    /*               TREE BUILDING METHODS             */
    /*                                                 */
    /* ----------------------------------------------- */

    /* CREATE ROOT NODE */

    public void createRoot(int newNodeID, String newQuestAns) {
	rootNode = new BinTree(newNodeID,newQuestAns);	
	System.out.println("Created root node " + newNodeID);	
	}
			
    /* ADD YES NODE */

    public void addYesNode(int existingNodeID, int newNodeID, String newQuestAns) {
	// If no root node do nothing
	
	if (rootNode == null) {
	    System.out.println("ERROR: No root node!");
	    return;
	    }
	
	// Search tree
	
	if (searchTreeAndAddYesNode(rootNode,existingNodeID,newNodeID,newQuestAns)) {
	    System.out.println("Added node " + newNodeID +
	    		" onto \"yes\" branch of node " + existingNodeID);
	    }
	else System.out.println("Node " + existingNodeID + " not found");
	}

    /* SEARCH TREE AND ADD YES NODE */

    private boolean searchTreeAndAddYesNode(BinTree currentNode,
                        int existingNodeID, int newNodeID, String newQuestAns) {
        if (currentNode.nodeID == existingNodeID) {
            // Found node: warn if an existing yes branch is about to be replaced
            if (currentNode.yesBranch != null) {
                System.out.println("WARNING: Overwriting previous node " +
                        "(id = " + currentNode.yesBranch.nodeID +
                        ") linked to yes branch of node " +
                        existingNodeID);
                }
            currentNode.yesBranch = new BinTree(newNodeID,newQuestAns);
            return(true);
            }

        // Try yes branch if it exists
        if (currentNode.yesBranch != null &&
                        searchTreeAndAddYesNode(currentNode.yesBranch,
                        existingNodeID,newNodeID,newQuestAns)) {
            return(true);
            }

        // Try no branch if it exists, whether or not there was a yes branch
        if (currentNode.noBranch != null) {
            return(searchTreeAndAddYesNode(currentNode.noBranch,
                        existingNodeID,newNodeID,newQuestAns));
            }
        return(false);          // Not found here
        }
    		
    /* ADD NO NODE */

    public void addNoNode(int existingNodeID, int newNodeID, String newQuestAns) {
	// If no root node do nothing
	
	if (rootNode == null) {
	    System.out.println("ERROR: No root node!");
	    return;
	    }
	
	// Search tree
	
	if (searchTreeAndAddNoNode(rootNode,existingNodeID,newNodeID,newQuestAns)) {
	    System.out.println("Added node " + newNodeID +
	    		" onto \"no\" branch of node " + existingNodeID);
	    }
	else System.out.println("Node " + existingNodeID + " not found");
	}
	
    /* SEARCH TREE AND ADD NO NODE */

    private boolean searchTreeAndAddNoNode(BinTree currentNode,
                        int existingNodeID, int newNodeID, String newQuestAns) {
        if (currentNode.nodeID == existingNodeID) {
            // Found node: warn if an existing no branch is about to be replaced
            if (currentNode.noBranch != null) {
                System.out.println("WARNING: Overwriting previous node " +
                        "(id = " + currentNode.noBranch.nodeID +
                        ") linked to no branch of node " +
                        existingNodeID);
                }
            currentNode.noBranch = new BinTree(newNodeID,newQuestAns);
            return(true);
            }

        // Try yes branch if it exists
        if (currentNode.yesBranch != null &&
                        searchTreeAndAddNoNode(currentNode.yesBranch,
                        existingNodeID,newNodeID,newQuestAns)) {
            return(true);
            }

        // Try no branch if it exists, whether or not there was a yes branch
        if (currentNode.noBranch != null) {
            return(searchTreeAndAddNoNode(currentNode.noBranch,
                        existingNodeID,newNodeID,newQuestAns));
            }
        return(false);          // Not found here
        }

    /* --------------------------------------------- */
    /*                                               */
    /*               TREE QUERY METHODS              */
    /*                                               */
    /* --------------------------------------------- */

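    public void queryBinTree() throws IOException {
        // If no root node there is nothing to query
        if (rootNode == null) {
            System.out.println("ERROR: No root node!");
            return;
            }
        queryBinTree(rootNode);
        }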

    private void queryBinTree(BinTree currentNode) throws IOException {

        // Test for leaf node (answer) and missing branches

        if (currentNode.yesBranch==null) {
            if (currentNode.noBranch==null) System.out.println(currentNode.questOrAns);
            else System.out.println("Error: Missing \"Yes\" branch at \"" +
            		currentNode.questOrAns + "\" question");
            return;
            }
        if (currentNode.noBranch==null) {
            System.out.println("Error: Missing \"No\" branch at \"" +
            		currentNode.questOrAns + "\" question");
            return;
            }

        // Question

        askQuestion(currentNode);
        }

    private void askQuestion(BinTree currentNode) throws IOException {
        System.out.println(currentNode.questOrAns + " (enter \"Yes\" or \"No\")");
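        String answer = keyboardInput.readLine();
        // readLine() returns null at end of input; stop rather than fail
        if (answer == null) {
            System.out.println("ERROR: No more input");
            return;
            }
        if (answer.equals("Yes")) queryBinTree(currentNode.yesBranch);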
        else {
            if (answer.equals("No")) queryBinTree(currentNode.noBranch);
            else {
                System.out.println("ERROR: Must answer \"Yes\" or \"No\"");
                askQuestion(currentNode);
                }
            }
        }

    /* ----------------------------------------------- */
    /*                                                 */
    /*               TREE OUTPUT METHODS               */
    /*                                                 */
    /* ----------------------------------------------- */

    /* OUTPUT BIN TREE */

    public void outputBinTree() {

        outputBinTree("1",rootNode);
        }

    private void outputBinTree(String tag, BinTree currentNode) {

        // Check for empty node

        if (currentNode == null) return;

        // Output

        System.out.println("[" + tag + "] nodeID = " + currentNode.nodeID +
        		", question/answer = " + currentNode.questOrAns);
        		
        // Go down yes branch

        outputBinTree(tag + ".1",currentNode.yesBranch);

        // Go down no branch

        outputBinTree(tag + ".2",currentNode.noBranch);
	}      		
    }

Table 1: Generic binary decision tree generator and query class
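
The search step shared by the two "add" methods in Table 1 can also be expressed as a single lookup that returns the node with a given nodeID, or null if there is none. The following is a minimal sketch of that alternative, not the author's implementation; the names FindNodeSketch, Node, findNode and addChild are hypothetical, and unlike Table 1 the sketch replaces an existing child without printing a warning.

```java
// Illustrative sketch only; all names here are hypothetical.
class FindNodeSketch {

    static class Node {
        final int nodeID;
        final String questOrAns;
        Node yesBranch;
        Node noBranch;

        Node(int nodeID, String questOrAns) {
            this.nodeID = nodeID;
            this.questOrAns = questOrAns;
        }
    }

    Node rootNode;

    // Depth-first search: test this node, then the "Yes" subtree,
    // then the "No" subtree. Returns null if the ID is not present.
    static Node findNode(Node current, int id) {
        if (current == null) return null;
        if (current.nodeID == id) return current;
        Node found = findNode(current.yesBranch, id);
        return (found != null) ? found : findNode(current.noBranch, id);
    }

    // Attach a new node to the chosen branch of an existing node.
    // Returns false if the existing node cannot be found.
    // Assumption: an existing child on that branch is silently replaced.
    boolean addChild(int existingNodeID, boolean onYesBranch,
                     int newNodeID, String newQuestAns) {
        Node parent = findNode(rootNode, existingNodeID);
        if (parent == null) return false;
        Node child = new Node(newNodeID, newQuestAns);
        if (onYesBranch) parent.yesBranch = child;
        else parent.noBranch = child;
        return true;
    }

    public static void main(String[] args) {
        FindNodeSketch tree = new FindNodeSketch();
        tree.rootNode = new Node(1, "Does animal eat meat?");
        tree.addChild(1, false, 3, "Does animal have stripes?");
        // Node 3 is reachable even though node 1 has no "Yes" branch yet.
        System.out.println(tree.addChild(3, true, 6, "Animal is a Zebra"));
        // There is no node 9, so nothing is added.
        System.out.println(tree.addChild(9, true, 10, "Animal is unknown"));
    }
}
```

Separating the lookup from the modification means the traversal logic is written once and serves both the "Yes" and the "No" case.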

Table 2 shows an application class that makes use of the code presented in Table 1. Here we create a decision tree to identify four different kinds of animal: Tigers, Leopards, Zebras and Horses.

// DECISION TREE  APPLICATION
// Frans Coenen
// Thursday 15 August 2002
// Department of Computer Science, University of Liverpool

import java.io.*;

class DecisionTreeApp {

    /* ------------------------------- */
    /*                                 */
    /*              FIELDS             */
    /*                                 */
    /* ------------------------------- */

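    // Share the reader created by DecisionTree: two separate buffered
    // readers on System.in can consume each other's input when it is piped
    static BufferedReader keyboardInput = DecisionTree.keyboardInput;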
    static DecisionTree newTree;

    /* --------------------------------- */
    /*                                   */
    /*               METHODS             */
    /*                                   */
    /* --------------------------------- */

    /* MAIN */

    public static void main(String[] args) throws IOException {

        // Create instance of class DecisionTree

        newTree = new DecisionTree();

        // Generate tree

        generateTree();

        // Output tree

        System.out.println("\nOUTPUT DECISION TREE");
        System.out.println("====================");
        newTree.outputBinTree();

        // Query tree

        queryTree();
        }

    /* GENERATE TREE */

    static void generateTree() {
        System.out.println("\nGENERATE DECISION TREE");
        System.out.println("======================");
        newTree.createRoot(1,"Does animal eat meat?");
        newTree.addYesNode(1,2,"Does animal have stripes?");
        newTree.addNoNode(1,3,"Does animal have stripes?");
        newTree.addYesNode(2,4,"Animal is a Tiger");
        newTree.addNoNode(2,5,"Animal is a Leopard");
        newTree.addYesNode(3,6,"Animal is a Zebra");
        newTree.addNoNode(3,7,"Animal is a Horse");
        }

    /* QUERY TREE */
	
    static void queryTree() throws IOException {
        System.out.println("\nQUERY DECISION TREE");
        System.out.println("===================");
        newTree.queryBinTree();

        // Option to exit

        optionToExit();
        }

    /* OPTION TO EXIT PROGRAM */

    static void optionToExit() throws IOException {
        System.out.println("Exit? (enter \"Yes\" or \"No\")");
        String answer = keyboardInput.readLine();
        // readLine() returns null at end of input; treat that as a request to exit
        if (answer == null || answer.equals("Yes")) return;
        else {
            if (answer.equals("No")) queryTree();
            else {
                System.out.println("ERROR: Must answer \"Yes\" or \"No\"");
                optionToExit();
                }
            }
        }
    }

Table 2: Application class for the above

Some sample output is presented in Table 3.

$ java DecisionTreeApp

GENERATE DECISION TREE
======================
Created root node 1
Added node 2 onto "yes" branch of node 1
Added node 3 onto "no" branch of node 1
Added node 4 onto "yes" branch of node 2
Added node 5 onto "no" branch of node 2
Added node 6 onto "yes" branch of node 3
Added node 7 onto "no" branch of node 3

OUTPUT DECISION TREE
====================
[1] nodeID = 1, question/answer = Does animal eat meat?
[1.1] nodeID = 2, question/answer = Does animal have stripes?
[1.1.1] nodeID = 4, question/answer = Animal is a Tiger
[1.1.2] nodeID = 5, question/answer = Animal is a Leopard
[1.2] nodeID = 3, question/answer = Does animal have stripes?
[1.2.1] nodeID = 6, question/answer = Animal is a Zebra
[1.2.2] nodeID = 7, question/answer = Animal is a Horse

QUERY DECISION TREE
===================
Does animal eat meat? (enter "Yes" or "No")
Yes
Does animal have stripes? (enter "Yes" or "No")
Yes
Animal is a Tiger
Exit? (enter "Yes" or "No")
No

QUERY DECISION TREE
===================
Does animal eat meat? (enter "Yes" or "No")
No
Does animal have stripes? (enter "Yes" or "No")
No
Animal is a Horse
Exit? (enter "Yes" or "No")
Yes

Table 3: Some sample output

Note that there are many improvements that can be made to the above; for example, instead of generating the tree on each invocation we could read the tree from a file. Similarly, there is no error checking to ensure that the tree, once constructed, is properly "balanced", i.e. that the root and body nodes all have exactly two child nodes. However, the above does suffice to indicate the general idea.
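
As a hedged sketch of the two improvements mentioned above, the code below builds a tree from lines of text and then checks that every node has either no children or exactly two. The line format (parentID,branch,nodeID,text, where branch is root, yes or no, and parents are listed before their children) is an assumption made purely for this illustration, as are the names TreeFileSketch, Node, load and isComplete. The lines are given in memory here; they could equally be obtained from a file, for example with java.nio.file.Files.readAllLines.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative sketch only; the names and the line format are hypothetical.
class TreeFileSketch {

    static class Node {
        final int nodeID;
        final String questOrAns;
        Node yesBranch;
        Node noBranch;

        Node(int nodeID, String questOrAns) {
            this.nodeID = nodeID;
            this.questOrAns = questOrAns;
        }
    }

    // Build a tree from lines of the assumed form "parentID,branch,nodeID,text".
    // The parentID field is ignored on the root line.
    static Node load(List<String> lines) {
        Map<Integer, Node> byId = new HashMap<>();
        Node root = null;
        for (String line : lines) {
            String[] parts = line.split(",", 4);
            if (parts.length != 4) {
                throw new IllegalArgumentException("Malformed line: " + line);
            }
            String branch = parts[1].trim();
            Node node = new Node(Integer.parseInt(parts[2].trim()), parts[3].trim());
            if (branch.equals("root")) {
                root = node;
            } else {
                Node parent = byId.get(Integer.parseInt(parts[0].trim()));
                if (parent == null) {
                    throw new IllegalArgumentException("Unknown parent in: " + line);
                }
                if (branch.equals("yes")) parent.yesBranch = node;
                else if (branch.equals("no")) parent.noBranch = node;
                else throw new IllegalArgumentException("Unknown branch in: " + line);
            }
            byId.put(node.nodeID, node);
        }
        return root;
    }

    // True if every node is either a leaf or has both a "Yes" and a "No" child.
    // An empty tree (null) is reported as not complete.
    static boolean isComplete(Node node) {
        if (node == null) return false;
        boolean hasYes = node.yesBranch != null;
        boolean hasNo = node.noBranch != null;
        if (!hasYes && !hasNo) return true;     // leaf node (answer)
        if (hasYes != hasNo) return false;      // exactly one branch missing
        return isComplete(node.yesBranch) && isComplete(node.noBranch);
    }

    public static void main(String[] args) {
        List<String> lines = Arrays.asList(
                "0,root,1,Does animal eat meat?",
                "1,yes,2,Does animal have stripes?",
                "1,no,3,Does animal have stripes?",
                "2,yes,4,Animal is a Tiger",
                "2,no,5,Animal is a Leopard",
                "3,yes,6,Animal is a Zebra",
                "3,no,7,Animal is a Horse");
        System.out.println(isComplete(load(lines)));                // all branches present
        System.out.println(isComplete(load(lines.subList(0, 6))));  // node 3 lacks a "No" child
    }
}
```

Running the check once after construction would let the query methods assume that every question node has both branches.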




Created and maintained by Frans Coenen. Last updated 05 December 2002