Commit f9a816aa authored by Vikram Waradpande's avatar Vikram Waradpande
Browse files

Add conv2d layer

parent 79c2ad36
......@@ -64,7 +64,7 @@ START = 0.5
#Run params
START_STATE = 1
MAX_EPISODES = 60
EMBTOGGLE = 1
EMBTOGGLE = 2
DIMENSION = 20
TARGET_LOC = 399
EMBEDPATH = "./Embeddings/"
......@@ -133,12 +133,10 @@ class DQNAgent:
model = Sequential()
if(EMBTOGGLE == 2):
model.add(Conv2D(8, kernel_size=(3, 3), strides=(1, 1),
activation='relu',
input_shape= (GRID, GRID,1)))
model.add(Conv2D(16, kernel_size=(3, 3), strides=(1, 1),
activation='relu',input_shape= (GRID, GRID,1)))
model.add(MaxPooling2D(pool_size=(3, 3), strides=(1, 1)))
model.add(Conv2D(16, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(3, 3)))
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu'))
model.add(Flatten())
model.add(Dense(50, activation='relu'))
model.add(Dense(4, activation='tanh'))
......@@ -146,7 +144,7 @@ class DQNAgent:
optimizer='adam')
return model
if(EMBTOGGLE == 0):
elif(EMBTOGGLE == 0):
model.add(Dense(64, input_shape=(
self.state_size,), activation='relu'))
else:
......@@ -171,7 +169,7 @@ class DQNAgent:
elif(EMBTOGGLE == 2):
inputs = np.zeros((batch_size, GRID, GRID, 1))
else:
inputs = np.zeros((batch_size, DIMENSION))
inputs = np.zeros((batch_size, DIMENSION))
#inputs = np.zeros((batch_size, 2*DIMENSION))
targets = np.zeros((batch_size, self.num_actions))
i = 0
......@@ -500,8 +498,8 @@ def deepQLearning(model, env, state, args, randomMode=False, **opt):
if(EMBTOGGLE == 1):
model.remember(cs, action, reward, ns, game_over)
elif(EMBTOGGLE == 2):
current_state = np.reshape(current_state, (1,10,10,))
next_state = np.reshape(next_state, (1,10,10,))
current_state = np.reshape(current_state, (1,GRID,GRID,))
next_state = np.reshape(next_state, (1,GRID,GRID,))
current_state = np.expand_dims(current_state,-1)
next_state = np.expand_dims(next_state,-1)
......@@ -531,10 +529,6 @@ def deepQLearning(model, env, state, args, randomMode=False, **opt):
saved_env = env
print("Reached 100%% win rate at episode: %d" % (episode,))
if save_file_path:
print("Saved model to file :{}".format(save_file_path))
model.save(save_file_path)
memory.sort(key=len)
memory = np.array(memory)
break
......@@ -663,7 +657,7 @@ def trainDQN(args):
state_index = -1
for _ in range(int(args.iterations)):
state = 18
state = 5
if state in obstacles_loc or state == TARGET_LOC:
continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment