Skip to content

Commit

Permalink
TimesNode should allow broadcasting when using element_times with red…
Browse files Browse the repository at this point in the history
…uce_sum for dot product

Update ResNet readme for Cha's CR
  • Loading branch information
KeDengMS committed Jun 7, 2017
1 parent 782d49a commit fca6f27
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
10 changes: 5 additions & 5 deletions Examples/Image/Classification/ResNet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,31 @@ We offer multiple ResNet examples, including ResNet20 and ResNet110 for CIFAR-10

### CNTK Pre-trained
Models pre-trained with CNTK scripts.
#### ResNet18
#### ResNet18 for ImageNet 1K
|CNTK model download path | https://www.cntk.ai/Models/CNTK_Pretrained/ResNet18_ImageNet_CNTK.model
|:-------|:---
|Training script | [ResNet18_ImageNet1K.cntk](./BrainScript/ResNet18_ImageNet1K.cntk)
|Single crop top 1 / top 5 error | 29.186% / 10.284%

#### ResNet34
#### ResNet34 for ImageNet 1K
|CNTK model download path | https://www.cntk.ai/Models/CNTK_Pretrained/ResNet34_ImageNet_CNTK.model
|:---------|:---
|Training script | [ResNet34_ImageNet1K.cntk](./BrainScript/ResNet34_ImageNet1K.cntk)
|Single crop top 1 / top 5 error | 27.240% / 8.962%

#### ResNet50
#### ResNet50 for ImageNet 1K
|CNTK model download path | https://www.cntk.ai/Models/CNTK_Pretrained/ResNet50_ImageNet_CNTK.model
|:---------|:---
|Training script | [ResNet50_ImageNet1K.cntk](./BrainScript/ResNet50_ImageNet1K.cntk)
|Single crop top 1 / top 5 error | 23.862% / 7.180%

#### ResNet20
#### ResNet20 for CIFAR-10
|CNTK model download path | https://www.cntk.ai/Models/CNTK_Pretrained/ResNet20_CIFAR10_Python.model
|:-------|:---
|Training script | [TrainResNet_CIFAR10.py --network resnet20](./Python/TrainResNet_CIFAR10.py)
|Single crop top 1 error | 8.17%

#### ResNet110
#### ResNet110 for CIFAR-10
|CNTK model download path | https://www.cntk.ai/Models/CNTK_Pretrained/ResNet110_CIFAR10_Python.model
|:-------|:---
|Training script | [TrainResNet_CIFAR10.py --network resnet110](./Python/TrainResNet_CIFAR10.py)
Expand Down
2 changes: 1 addition & 1 deletion Source/ComputationNetworkLib/LinearAlgebraNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ class TimesNodeBase : public ComputationNode<ElemType>, public NumInputs<2>
}
else
{
ElementTimesNode<ElemType>::ForwardPropImpl(*this, fr, false/*allowBroadcast*/);
ElementTimesNode<ElemType>::ForwardPropImpl(*this, fr, true/*allowBroadcast*/);
}
return;
}
Expand Down
6 changes: 6 additions & 0 deletions bindings/python/cntk/ops/tests/linear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,9 @@ def test_per_dim_mean_var_norm():
result = func.eval({x : np.asarray([[3.], [1.]], dtype=np.float32)})
assert np.array_equal(result, [[.5], [-.5]])

def test_times_const_broadcast():
x = C.input_variable((3,))
a = C.constant(np.ones((3,), dtype=np.float32))
y = C.times_transpose(a, x)
result = y.eval({x:np.asarray([[1,2,3],[1,2,3]], dtype=np.float32)})
assert np.array_equal(result, [[6], [6]])

0 comments on commit fca6f27

Please sign in to comment.