root/ggpa/QNeuralNetwork.cpp

Revision 1, 4.7 kB (checked in by pantley2, 4 years ago)

GGPA code from the good old days of SIGART

Line 
1 /*
2  * QNeuralNetwork.cpp
3  * Contains neural network wrapper.
4  */
5
6 #include "QNeuralNetwork.h"
7 //#include "fann/fann.h"
8 //#include "fann/fixedfann.h"
9 #include <iostream>
10
11 using namespace std;
12
13 QNeuralNetwork::QNeuralNetwork() : NUM_HIDDEN_NEURONS(4)
14 {
15     _ann = NULL;
16 }
17
/*
 * QNeuralNetwork(int actionLength, int stateLength)
 * Parameters:   actionLength - maximum length of action.toInputs()
 *               stateLength - maximum length of state.toInputs()
 *
 * Two-argument constructor (its definition follows operator= below).
 * Creates the underlying neural network.
 */
25
26 QNeuralNetwork QNeuralNetwork::operator=(const QNeuralNetwork &theNetwork)
27 {
28     QNeuralNetwork tempNetwork;
29     tempNetwork._ann = theNetwork._ann;
30     tempNetwork._actionLength = theNetwork._actionLength;
31     tempNetwork._stateLength = theNetwork._stateLength;
32     return tempNetwork;
33 }
34
35 QNeuralNetwork::QNeuralNetwork(int actionLength, int stateLength) : NUM_HIDDEN_NEURONS(4)
36 {
37     _actionLength = actionLength;
38     _stateLength = stateLength;
39     _ann = fann_create_standard(3, actionLength+stateLength, NUM_HIDDEN_NEURONS, 1);
40     // Should use a less arbitrary network structure at some point.
41
42     //  Values chosen randomly
43     //  Are all of these functions necessary?
44     //fann_randomize_weights(_ann, 0, 1);
45     fann_set_training_algorithm(_ann, FANN_TRAIN_INCREMENTAL);
46     fann_set_learning_rate( _ann, .2);
47     fann_set_learning_momentum(_ann, .3);
48     //fann_set_activation_function_hidden(_ann, FANN_SIGMOID_SYMMETRIC);
49     //fann_set_activation_function_output(_ann, FANN_SIGMOID_SYMMETRIC);
50     fann_set_activation_steepness_hidden(_ann, .7);
51     fann_set_activation_steepness_output(_ann, .5);
52    
53 }
54
55 QNeuralNetwork::~QNeuralNetwork()
56 {
57     fann_destroy(_ann);
58 }
59
60 /*
61  * update
62  * Parameters:   state - a State object that represents a state in the game
63  *               action - an Action object that represents an action in the
64  *                        game
65  *               utility - a double that represents the utility to insert
66  *                         for this State/Action pair
67  * No return value
68  *
69  * Updates the entry for the given State/Action pair.
70  */
71 void QNeuralNetwork::update(State &state, Action &action, double utility)
72 {
73     vector<double> input_vec = state.toInputs();
74     vector<double> action_vec = action.toInputs();
75
76     int initSize = input_vec.size();
77     for(int i = initSize; i < _stateLength; i++)
78        input_vec.push_back(0.0);
79    
80     input_vec.insert(input_vec.end(), action_vec.begin(), action_vec.end());
81
82     initSize = action_vec.size();
83     for(int i = initSize; i < _actionLength; i++)
84        input_vec.push_back(0.0);
85    
86    
87     // WARNING: this is a horrible, horrible hack. Somebody less lazy than me
88     // ought to rewrite it eventually.
89     fann_train(_ann, (fann_type*) &(input_vec)[0], (fann_type*) &utility);
90 }
91
92 /*
93  * getValue
94  * Parameters:   state - a State object that represents a state in the game
95  *               action - an Action object that represents an action in the
96  *                        game
97  * Return value: a double representing the utility of the given State/
98  *   Action pair.
99  * This function will return the current utility of the given State/Action
100  *   pair.  In other words, it will return a value representing how good
101  *   the action is from the state.
102  */
103 double QNeuralNetwork::getValue(State &state, Action &action)
104 {
105     fann_type *output;
106     vector<double> input_vec = state.toInputs();
107     vector<double> action_vec = action.toInputs();
108
109     int initSize = input_vec.size();
110     for(int i = initSize; i < _stateLength; i++)
111        input_vec.push_back(0.0);
112
113     input_vec.insert(input_vec.end(), action_vec.begin(), action_vec.end());
114    
115     initSize = action_vec.size();
116     for(int i = initSize; i < _actionLength; i++)
117        input_vec.push_back(0.0);
118
119     // WARNING: this is a horrible, horrible hack. Somebody less lazy than me
120     // ought to rewrite it eventually.
121     input_vec.insert(input_vec.end(), action_vec.begin(), action_vec.end());
122     output = fann_run(_ann, (fann_type*) &(input_vec)[0]);
123     return output[0];
124 }
125
126 /*
127  * writeToFile
128  *
129  * Parameters:   file - String containing the name of the file to write to.
130  * Return value: none
131  *
132  * Writes the QNeuralNetwork to the given file.  If the file does not exist, it
133  *   will be created.  If it does exist, then it will be written to the
134  *   end of the file
135  */
136 void QNeuralNetwork::writeToFile(string filename)
137 {
138     fann_save(_ann, &(filename[0]));   
139 }
140
141 /*
142  * readFromFile
143  *
144  * Parameters:   file - string containing the name of the file to read from
145  * Return value: none
146  *
147  * Reads in a QNeuralNetwork from a file.  The current QNeuralNetwork will be set to
148  *    be this saved QNeuralNetwork
149  */
150 void QNeuralNetwork::readFromFile(string filename)
151 {
152     _ann = fann_create_from_file(&(filename[0]));
153 }
Note: See TracBrowser for help on using the browser.