diff --git a/MapTable.lua b/MapTable.lua index c79f1ea1d..0eda8cb3a 100644 --- a/MapTable.lua +++ b/MapTable.lua @@ -3,14 +3,14 @@ local MapTable, parent = torch.class('nn.MapTable', 'nn.Container') function MapTable:__init(module, shared) parent.__init(self) self.shared = (shared == nil) and true or shared - self.sharedparams = {'weight', 'bias', 'gradWeight', 'gradBias'} + self.sharedparams = {'weight', 'bias', 'gradWeight', 'gradBias', 'running_mean', 'runnig_var', 'save_mean', 'save_var'} self.output = {} self.gradInput = {} self:add(module) end function MapTable:_extend(n) - self.sharedparams = self.sharedparams or {'weight', 'bias', 'gradWeight', 'gradBias'} + self.sharedparams = self.sharedparams or {'weight', 'bias', 'gradWeight', 'gradBias', 'running_mean', 'runnig_var', 'save_mean', 'save_var'} self.modules[1] = self.module for i = 2, n do if not self.modules[i] then