



#include "TestTools.h"
#include "QNeuralNetwork.h"
#include "State.h"
#include "Action.h"


#include <iostream>
using namespace std;

int main() |
|---|
| 15 |
{ |
|---|
| 16 |
int stateLength = 0; |
|---|
| 17 |
State state0(2); |
|---|
| 18 |
State state1(2); |
|---|
| 19 |
State state2(2); |
|---|
| 20 |
State state3(2); |
|---|
| 21 |
{ |
|---|
| 22 |
SentenceStructure* structure = new SentenceStructure("foo", 1, false); |
|---|
| 23 |
structure->addValue(0, "1"); |
|---|
| 24 |
structure->addValue(0, "2"); |
|---|
| 25 |
structure->setOffset(0); |
|---|
| 26 |
|
|---|
| 27 |
vector<string> values0(1, "1"); |
|---|
| 28 |
vector<string> values1(1, "2"); |
|---|
| 29 |
|
|---|
| 30 |
Sentence sentence0(values0, structure); |
|---|
| 31 |
Sentence sentence1(values1, structure); |
|---|
| 32 |
|
|---|
| 33 |
|
|---|
| 34 |
state1.addSentence(&sentence0); |
|---|
| 35 |
|
|---|
| 36 |
state2.addSentence(&sentence1); |
|---|
| 37 |
|
|---|
| 38 |
state3.addSentence(&sentence0); |
|---|
| 39 |
state3.addSentence(&sentence1); |
|---|
| 40 |
|
|---|
| 41 |
int l1, l2, l3; |
|---|
| 42 |
l1 = state1.toInputs().size(); |
|---|
| 43 |
l2 = state2.toInputs().size(); |
|---|
| 44 |
l3 = state3.toInputs().size(); |
|---|
| 45 |
stateLength = l1; |
|---|
| 46 |
if(l2 > stateLength) |
|---|
| 47 |
stateLength = l2; |
|---|
| 48 |
if(l3 > stateLength) |
|---|
| 49 |
stateLength = l3; |
|---|
| 50 |
} |
|---|
| 51 |
|
|---|
| 52 |
int maxActionLength = 0; |
|---|
| 53 |
Action *action = NULL, *action2 = NULL, *action3 = NULL; |
|---|
| 54 |
{ |
|---|
| 55 |
SentenceStructure *structure = NULL, *structure2 = NULL; |
|---|
| 56 |
vector<string> values, values2, values3; |
|---|
| 57 |
values.push_back("1"); |
|---|
| 58 |
values.push_back("2"); |
|---|
| 59 |
values2.push_back("1"); |
|---|
| 60 |
values2.push_back("2"); |
|---|
| 61 |
values2.push_back("0"); |
|---|
| 62 |
values3.push_back("1"); |
|---|
| 63 |
values3.push_back("2"); |
|---|
| 64 |
values3.push_back("0"); |
|---|
| 65 |
|
|---|
| 66 |
structure = new SentenceStructure("succ", 2, true); |
|---|
| 67 |
structure2= new SentenceStructure("cell", 3, true); |
|---|
| 68 |
structure->setOffset(0); |
|---|
| 69 |
structure->addValue(0, "0"); |
|---|
| 70 |
structure->addValue(0, "1"); |
|---|
| 71 |
structure->addValue(1, "0"); |
|---|
| 72 |
structure->addValue(1, "1"); |
|---|
| 73 |
structure->addValue(1, "2"); |
|---|
| 74 |
|
|---|
| 75 |
structure2->setOffset(6); |
|---|
| 76 |
structure2->addValue(0, "0"); |
|---|
| 77 |
structure2->addValue(0, "1"); |
|---|
| 78 |
structure2->addValue(0, "2"); |
|---|
| 79 |
structure2->addValue(1, "0"); |
|---|
| 80 |
structure2->addValue(1, "1"); |
|---|
| 81 |
structure2->addValue(1, "2"); |
|---|
| 82 |
structure2->addValue(2, "0"); |
|---|
| 83 |
structure2->addValue(2, "1"); |
|---|
| 84 |
structure2->addValue(2, "2"); |
|---|
| 85 |
|
|---|
| 86 |
vector<string> temp; |
|---|
| 87 |
temp.push_back(""); |
|---|
| 88 |
|
|---|
| 89 |
action = new Action(values, structure, "", temp); |
|---|
| 90 |
action2 = new Action(values2, structure2, "", temp); |
|---|
| 91 |
action3 = new Action(values3, structure2, "", temp); |
|---|
| 92 |
|
|---|
| 93 |
int l1, l2, l3; |
|---|
| 94 |
l1 = action->toInputs().size(); |
|---|
| 95 |
l2 = action2->toInputs().size(); |
|---|
| 96 |
l3 = action3->toInputs().size(); |
|---|
| 97 |
|
|---|
| 98 |
maxActionLength = l1; |
|---|
| 99 |
if(l2 > maxActionLength) |
|---|
| 100 |
maxActionLength = l2; |
|---|
| 101 |
if(l3 > maxActionLength) |
|---|
| 102 |
maxActionLength = l3; |
|---|
| 103 |
} |
|---|
| 104 |
|
|---|
| 105 |
QNeuralNetwork network (maxActionLength, stateLength); |
|---|
| 106 |
|
|---|
| 107 |
for(int i=0; i< 100; i++) |
|---|
| 108 |
{ |
|---|
| 109 |
network.update(state0, *action, 1.0); |
|---|
| 110 |
|
|---|
| 111 |
|
|---|
| 112 |
|
|---|
| 113 |
|
|---|
| 114 |
} |
|---|
| 115 |
cout << network.getValue(state0, *action) << endl; |
|---|
| 116 |
cout << network.getValue(state1, *action2) << endl; |
|---|
| 117 |
cout << network.getValue(state2, *action) << endl; |
|---|
| 118 |
cout << network.getValue(state3, *action3) << endl; |
|---|
| 119 |
cout << network.getValue(state1, *action2) << endl; |
|---|
| 120 |
} |
|---|
| 121 |
|
|---|