root/ggpa/QLearner.cpp

Revision 1, 7.2 kB (checked in by pantley2, 4 years ago)

GGPA code from the good old days of SIGART

Line 
1 /*
2  * QLearner.cpp
3  * Class that will handle the QLearning algorithm
4  */
5
6
7
#include "QLearner.h"

#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

#include "GameWorld.h"
#include "QFunction.h"
#include "QNeuralNetwork.h"
#include "State.h"
#include "Action.h"
17
// Discount factor applied to the estimated utility of the successor state
// in the Q-learning target (reward + GAMMA * nextUtility, see
// updateUtility). A typed, file-local constant rather than a macro so it
// obeys scope and type rules.
static const double GAMMA = 0.3;
19
20
21 /*
22  * Initializes the QLearner to learn on the given world.
23  */
24 QLearner::QLearner(GameWorld* world) {
25     gameWorld = world;
26     qFunction = new QNeuralNetwork(gameWorld->getState()->getSize(), 5);
27 }
28
29 /*
30  * learn
31  * Parameters: learnTime - the amount of time allowed for learning.
32  *                         by default it will use the time dictated
33  *                         in the GameWorld class.
34  * Return type: none.
35  * Begins training on the given game.
36  */
37 void QLearner::learn(double learnTime) {
38     if (learnTime == DEFAULT_LEARN_TIME) {
39         learnTime = gameWorld->getStartClock();
40     }
41
42     gameWorld->startGame();
43
44     // NOTE: timing measure is temporary!!!!!
45     for (int count = 0; count < learnTime; count++) {
46         // perform a learning iteration.
47         if (gameWorld->isDone()) {
48             gameWorld->startGame();
49         }
50
51         State* state = gameWorld->getState();
52
53         vector<Action*> legalMoves = gameWorld->getLegalMoves();
54         Action* action = chooseMove(legalMoves);
55
56         // perform the action in the GameWorld
57         performAction(action);
58
59         State* newState = gameWorld->getState();
60         vector<Action*> newActions = gameWorld->getLegalMoves();
61        
62         int reward = gameWorld->getReward(gameWorld->getRole());
63         double nextUtility = getMaxUtility(newState, newActions);
64
65         // update the utility
66         updateUtility(state, action, reward, nextUtility);
67     }
68    
69 }
70
71 /*
72  * getAction
73  * Parameters: decisionTime - the amount of time allowed for making the
74  *                            decision.  By default, the time stored
75  *                            in the GameWorld class is used.
76  * Return type: Action*
77  * Returns an action for the current state
78  */
79 Action* QLearner::getAction(double decisionTime) {
80     // get the current state
81     State* state = gameWorld->getState();
82
83     // get the possible actions
84     vector<Action*> actions = gameWorld->getLegalMoves();
85
86     // find the action with the highest utility
87     double maxUtility = 0;
88     Action* bestAction = NULL;
89     for (unsigned int i = 0; i < actions.size(); i++) {
90         double utility = qFunction->getValue(*state, *(actions[i]));
91         if (bestAction == NULL || utility > maxUtility) {
92             maxUtility = utility;
93             bestAction = actions[i];
94         }
95     }
96
97     return bestAction;
98 }
99
100 /**
101  * performAction
102  * Parameters: action - the action to perform
103  * Return type: none
104  * Performs the action in the game world, handling any extra overhead
105  */
106 void QLearner::performAction(Action* action) {
107     vector<string> roles = gameWorld->getRoles();
108     vector<Action*> actions(roles.size());
109    
110     for (unsigned int roleNum = 0; roleNum < roles.size(); roleNum++) {
111         if (roles[roleNum] == gameWorld->getRole()) {
112             actions[roleNum] = action;
113         } else {
114             vector<Action*> possibleActions
115                 = gameWorld->getLegalMoves(roles[roleNum]);
116             actions[roleNum] = getRandomMove(possibleActions);
117         }
118     }
119
120     // call the update
121     gameWorld->update(actions);
122 }
123
124 /**
125  * getMaxUtility
126  * Parameters: state - the current state
127  *             actions - the set of actions we are interested in
128  * Return type: double
129  * Returns the maximum value for the QFunction over all of the actions.
130  */
131 double QLearner::getMaxUtility(State* state, vector<Action*> actions) {
132     double max = 0;
133     for (unsigned int actionIndex = 0; actionIndex < actions.size();
134          actionIndex++) {
135         Action* action = actions[actionIndex];
136         double utility = qFunction->getValue(*state, *action);
137         if (utility > max) {
138             max = utility;
139         }
140     }
141     return max;
142 }
143
144 /*
145  * updateUtility
146  * Parameters: state - the current state
147  *             action - the current action
148  *             reward - the immediate reward
149  *             nextUtility - the expected utility of the resulting state
150  * No return type
151  * Updates the utility for the given state-action pair to a value
152  * determined by reward and nextUtility.
153  */
154 void QLearner::updateUtility(State* state, Action* action, int reward,
155                    double nextUtility) {
156     double utility = reward + GAMMA * nextUtility;
157     qFunction->update(*state, *action, utility);
158 }
159
160
161
162 /*
163  * chooseMove
164  * Parameters: possibleActions - a vector of actions to choose from
165  * Return type: Action*
166  * Returns one of the actions from the vector.  This function is intended
167  * for use with the primary role during the learning process.  Using it
168  * in other conditions may give unexpected results.
169  */
170 Action* QLearner::chooseMove(vector<Action*> possibleActions) const {
171     // initialize the weights vector
172     vector<double> weights(possibleActions.size());
173     for (unsigned int i = 0; i < possibleActions.size(); i++) {
174         weights[i] = qFunction->getValue(*gameWorld->getState(),
175                                         *possibleActions[i]);
176     }
177
178     // randomly choose an action according to these weights
179     return getRandomMove(possibleActions, weights);
180 }
181
182 /*
183  * getRandomMove
184  * Parameters: possibleActions - a vector of actions to choose from
185  * Return type: Action*
186  * Randomly selects one of the actions from the vector.  Random seeding
187  * is not handled in this function.
188  */
189 Action* QLearner::getRandomMove(vector<Action*> possibleActions) const {
190     // initialize a weights vector
191     vector<double> weights(possibleActions.size());
192     for (unsigned int num = 0; num < possibleActions.size(); num++) {
193         weights[num] = 1.0;
194     }
195    
196     // call the other getRandomMove function
197     return getRandomMove(possibleActions, weights);
198 }
199
200 /*
201  * getRandomMove
202  * Parameters: possibleActions - a vector of actions to choose from
203  *             weights - a vector of double representing the weights
204  *                       for each of the actions.
205  * Return type: Action*
206  * Randomly selects one of the actions from the vector based on the
207  * distribution defined by weights.  This does not need to be normalized.
208  */
209 Action* QLearner::getRandomMove(vector<Action*> possibleActions,
210                                 vector<double> weights) const {
211     // determine the number of actions to be considered
212     unsigned int maxIndex = possibleActions.size();
213     if (weights.size() < possibleActions.size()) {
214         maxIndex = weights.size();
215     }
216
217     if (maxIndex == 0) {
218         cerr << "Empty set of actions" << endl;
219         exit(1);
220     }
221
222     // sum up the weights
223     double sum = 0;
224     for (unsigned int i = 0; i < maxIndex; i++) {
225         sum += weights[i];
226     }
227
228     // get the random number
229     double randomNumber = sum * (rand() / (RAND_MAX + 1.0));
230    
231     // find the corresponding action
232     double curMax = 0;
233     for (unsigned int i = 0; i < maxIndex; i++) {
234         curMax += weights[i];
235         if (randomNumber <= curMax) {
236             return possibleActions[i];
237         }
238     }
239
240     // should never get here
241     cerr << "No action was selected! curMax = " << curMax << " rand = "
242          << randomNumber << endl;
243     exit(1);
244    
245 }
Note: See TracBrowser for help on using the browser.