update split_lod_tensor, create_array and array_length doc #11383

Merged: 28 commits, Jun 17, 2018
Changes from 15 commits

Commits (28)
8e19c32
update split_lod_tensor, create_array and array_length doc
jacquesqiao Jun 12, 2018
2c1e2ca
update document
jacquesqiao Jun 12, 2018
4d0fd7e
add API reference for create_tensor
jacquesqiao Jun 12, 2018
d824229
add doc for batch norm
jacquesqiao Jun 12, 2018
b645dfa
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jun 12, 2018
f3e631c
small update
jacquesqiao Jun 12, 2018
e72eb0e
small update
jacquesqiao Jun 12, 2018
dde0a28
add doc for Switch
jacquesqiao Jun 12, 2018
d76f8a8
refine doc of polynomial_decay
jacquesqiao Jun 12, 2018
fd9b650
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jun 13, 2018
76129f0
update comment
jacquesqiao Jun 13, 2018
0ae6709
update document
jacquesqiao Jun 14, 2018
21ecd35
little optimize
jacquesqiao Jun 14, 2018
62bf672
update document for Switch
jacquesqiao Jun 14, 2018
2f9ed97
follow comment
jacquesqiao Jun 14, 2018
3a25cee
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jun 15, 2018
9de779f
update switch class
jacquesqiao Jun 15, 2018
e2783bb
update split_lod_tensor doc
jacquesqiao Jun 15, 2018
1c9fc65
update
jacquesqiao Jun 15, 2018
6ace04f
update
jacquesqiao Jun 15, 2018
5b50307
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jun 15, 2018
8f59d79
update doc for sigmoid_cross_entropy_with_logits
jacquesqiao Jun 15, 2018
bf3ff5b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jun 16, 2018
a4ee0d0
add reverse
jacquesqiao Jun 16, 2018
d1a8498
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jun 17, 2018
82a4cf1
update image_resize_short and shape doc
jacquesqiao Jun 17, 2018
b77c886
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jun 17, 2018
46ae1c9
add doc for softmax
jacquesqiao Jun 17, 2018
4 changes: 3 additions & 1 deletion paddle/fluid/operators/detection/polygon_box_transform_op.cc
@@ -83,11 +83,13 @@ class PolygonBoxTransformOpMaker : public framework::OpProtoAndCheckerMaker {

AddComment(R"DOC(
PolygonBoxTransform Operator.

PolygonBoxTransform Operator is used to transform the coordinate shift to the real coordinate.

The input is the final geometry output of the detection network.
We use 2*n numbers to denote the coordinate shift from n corner vertices of
the polygon_box to the pixel location. As each distance offset contains two numbers (xi, yi),
the geometry output contains 2*n channels.
)DOC");
}
};
73 changes: 55 additions & 18 deletions python/paddle/fluid/layers/control_flow.py
@@ -55,18 +55,17 @@

def split_lod_tensor(input, mask, level=0):
"""
This function takes in an input that contains the complete lod information,
and takes in a mask which is used to mask certain parts of the input.
The output is the true branch and the false branch with the mask applied to
the input at a certain level in the tensor. It is mainly used in IfElse to split
data into two parts.

Args:
input(tuple|list|None): The input tensor that contains complete
lod information needed to construct the output.
mask(list): A bool column vector which masks the input.
level(int): The specific lod level to split.

Returns:
Variable: The true branch of tensor as per the mask applied to input.
@@ -75,14 +74,15 @@ def split_lod_tensor(input, mask, level=0):
Examples:
.. code-block:: python

x = fluid.layers.data(name='x', shape=[1])
x.persistable = True

y = fluid.layers.data(name='y', shape=[1])
y.persistable = True

out_true, out_false = fluid.layers.split_lod_tensor(
    input=x, mask=y, level=0)

"""
helper = LayerHelper('split_lod_tensor', **locals())
out_true = helper.create_tmp_variable(dtype=input.dtype)
Expand All @@ -105,16 +105,17 @@ def merge_lod_tensor(in_true, in_false, x, mask, level=0):

This function takes in an input :math:`x`, the True branch, the False
branch and a binary :math:`mask`. Using this information, this function
merges the True and False branches of the tensor into a single tensor as
output at a certain lod level indicated by :math:`level`. It is used in IfElse
to merge the outputs of the True block and the False block.

Args:
in_true(tuple|list|None): The True branch to be merged.
in_false(tuple|list|None): The False branch to be merged.
x(tuple|list|None): The input tensor that contains complete
lod information needed to construct the output.
mask(list): A bool column vector which masks the input.
level(int): The specific lod level to merge.

Returns:
Variable: The merged output tensor.
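
A minimal usage sketch (not part of this diff), reusing ``x``, ``y``, ``out_true``
and ``out_false`` from the split_lod_tensor example above:

.. code-block:: python

    out = fluid.layers.merge_lod_tensor(
        in_true=out_true, in_false=out_false, x=x, mask=y, level=0)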
@@ -887,14 +888,17 @@ def array_write(x, i, array=None):


def create_array(dtype):
"""This function creates an array of type :math:`LOD_TENSOR_ARRAY` using the
LayerHelper.
"""
**Create LoDTensorArray**

This function creates an array of type LOD_TENSOR_ARRAY. It is mainly used to
implement RNN with array_write, array_read and While.

Args:
dtype (int|float): The data type of the elements in the lod_tensor_array.

Returns:
Variable: The lod_tensor_array variable storing the elements of the given data type.

Examples:
.. code-block:: python
@@ -983,7 +987,8 @@ def array_read(array, i):
Returns:
Variable: The tensor type variable that has the data written to it.
Examples:
.. code-block:: python

tmp = fluid.layers.zeros(shape=[10], dtype='int32')
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
arr = fluid.layers.array_read(tmp, i=i)
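
The snippet above reads from a plain tensor; a fuller sketch (assumed, not part
of this diff) writes an entry into a LoDTensorArray first and then reads it back:

.. code-block:: python

    tmp = fluid.layers.zeros(shape=[10], dtype='int32')
    i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
    arr = fluid.layers.array_write(tmp, i=i)   # arr becomes a LOD_TENSOR_ARRAY
    item = fluid.layers.array_read(arr, i=i)   # reads arr[0] back out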
@@ -1020,9 +1025,14 @@ def shrink_memory(x, i, table):


def array_length(array):
"""This function performs the operation to find the length of the input
"""
**Get the Length of Input LoDTensorArray**

This function performs the operation to find the length of the input
LOD_TENSOR_ARRAY.

Related API: array_read, array_write, While.

Args:
array (LOD_TENSOR_ARRAY): The input array that will be used
to compute the length.
@@ -1031,12 +1041,13 @@ def array_length(array):
Variable: The length of the input LoDTensorArray.

Examples:
.. code-block:: python

tmp = fluid.layers.zeros(shape=[10], dtype='int32')
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
arr = fluid.layers.array_write(tmp, i=i)
arr_len = fluid.layers.array_length(arr)

"""
helper = LayerHelper('array_length', **locals())
tmp = helper.create_tmp_variable(dtype='int64')
@@ -1120,6 +1131,31 @@ def complete(self):


class Switch(object):
"""
**Switch Class**

Many programming languages provide `switch` as a generalization of `if-elif-else`.
The Switch class works just like an `if-elif-else` chain.

The Semantics:

1. A `switch` control-flow checks cases one-by-one.

2. The condition of each case is a boolean value, which is a scalar.

3. It runs the first matched case, or the default case if there is one.

4. Once it matches a case, it runs the corresponding branch and only that branch.

Examples:
.. code-block:: python

with fluid.layers.Switch() as switch:
    with switch.case(global_step == zero_var):
        fluid.layers.assign(input=one_var, output=div_res)
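    # A hedged extension (not in the original example): switch.default()
    # defines the branch that runs when no case condition matches.
    with switch.default():
        fluid.layers.assign(input=zero_var, output=div_res)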

"""

def __init__(self, name=None):
self.helper = LayerHelper('switch', name=name)
self.inside_scope = False
@@ -1149,7 +1185,8 @@ def case(self, condition):
return ConditionalBlockGuard(cond_block)

def default(self):
"""create a default case for this switch
"""
create a default case for this switch
"""
pre_cond_num = len(self.pre_not_conditions)
if pre_cond_num == 0:
31 changes: 18 additions & 13 deletions python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -162,22 +162,27 @@ def polynomial_decay(learning_rate,
end_learning_rate=0.0001,
power=1.0,
cycle=False):
"""Applies polynomial decay to the initial learning rate.
"""
**Polynomial Decay**

Applies polynomial decay to the initial learning rate.

.. code-block:: python

if cycle:
decay_steps = decay_steps * ceil(global_step / decay_steps)
else:
global_step = min(global_step, decay_steps)
decayed_learning_rate = (learning_rate - end_learning_rate) *
(1 - global_step / decay_steps) ^ power + end_learning_rate


Args:
learning_rate(Variable|float32): A scalar float32 value or a Variable. This
will be the initial learning rate during training.
decay_steps(int32): A Python `int32` number.
end_learning_rate(float, Default: 0.0001): A Python `float` number.
power(float, Default: 1.0): A Python `float` number.
cycle(bool, Default: False): If set true, decay the learning rate every decay_steps.

Returns:
Variable: The decayed learning rate.
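
A usage sketch (assumed, not part of this diff): the decayed rate can be passed
directly to an optimizer.

.. code-block:: python

    lr = fluid.layers.polynomial_decay(
        learning_rate=0.01, decay_steps=10000,
        end_learning_rate=0.0001, power=1.0, cycle=True)
    sgd = fluid.optimizer.SGD(learning_rate=lr)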
68 changes: 50 additions & 18 deletions python/paddle/fluid/layers/nn.py
@@ -1611,27 +1611,57 @@ def batch_norm(input,
moving_variance_name=None,
do_model_average_for_mean_and_var=False):
"""
**Batch Normalization Layer**

Can be used as a normalizer function for conv2d and fully_connected operations.
The required data format for this layer is one of the following:

1. NHWC `[batch, in_height, in_width, in_channels]`

2. NCHW `[batch, in_channels, in_height, in_width]`

Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/pdf/1502.03167.pdf>`_
for more details.

:math:`input` is the input features over a mini-batch.

.. math::

\\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\
\ mini-batch\ mean \\\\
\\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\
\\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\
\\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\
\\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\ normalize \\\\
y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift

Args:
input(Variable): The input variable, which is a LoDTensor.
act(string, Default None): Activation type, linear|relu|prelu|...
is_test(bool, Default False): Used for training or testing.
momentum(float, Default 0.9): The value used for the moving mean and moving variance computation.
epsilon(float, Default 1e-05): A small float added to the variance to avoid dividing by zero.
param_attr(ParamAttr): The parameter attribute for Parameter `scale`.
bias_attr(ParamAttr): The parameter attribute for Parameter `bias`.
data_layout(string, default NCHW): NCHW|NHWC
in_place(bool, Default False): Make the input and output of batch norm reuse memory.
use_mkldnn(bool, Default false): ${use_mkldnn_comment}
name(string, Default None): A name for this layer (optional). If set None, the layer
will be named automatically.
moving_mean_name(string, Default None): The name of the moving_mean which stores the global mean.
moving_variance_name(string, Default None): The name of the moving_variance which stores the global variance.
do_model_average_for_mean_and_var(bool, Default False): Do model average for mean and variance or not.

Returns:
Variable: A tensor variable which is the result after applying batch normalization on the input.

Examples:

.. code-block:: python

hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
hidden2 = fluid.layers.batch_norm(input=hidden1)
"""
helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype()
@@ -4069,7 +4099,7 @@ def image_resize(input,
name=None,
resample='BILINEAR'):
"""
**Resize a Batch of Images**

The input must be a tensor of the shape (num_batches, channels, in_h, in_w),
and the resizing only applies on the last two dimensions (height and width).
@@ -4199,6 +4229,8 @@ def image_resize_short(input, out_short_len, resample='BILINEAR'):

def gather(input, index):
"""
**Gather Layer**

Output is obtained by gathering entries of the outer-most dimension
of X indexed by `index` and concatenating them together.

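A usage sketch (assumed, not part of this diff): gather rows 1 and 3 of a
[5 x 32] tensor X, producing a [2 x 32] output.

.. code-block:: python

    x = fluid.layers.data(name='x', shape=[5, 32],
                          dtype='float32', append_batch_size=False)
    index = fluid.layers.data(name='index', shape=[2],
                              dtype='int32', append_batch_size=False)
    out = fluid.layers.gather(x, index)  # out[i] = x[index[i]]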
18 changes: 18 additions & 0 deletions python/paddle/fluid/layers/tensor.py
@@ -39,6 +39,24 @@


def create_tensor(dtype, name=None, persistable=False):
"""
**Create a Tensor**

Args:
dtype (string): 'float32'|'int32'|..., the data type of the
created tensor.
name (string, Default: None): The name of the created tensor; if not set,
the name will be a random unique one.
persistable (bool, Default: False): Set the persistable flag of the created tensor.

Returns:
Variable: The tensor variable storing the created tensor.

Examples:
.. code-block:: python

tensor = fluid.layers.create_tensor(dtype='float32')
"""
helper = LayerHelper("create_tensor", **locals())
return helper.create_variable(
name=helper.name, dtype=dtype, persistable=persistable)