kiwi.modules.common.distributions
TruncatedNormal
kiwi.modules.common.distributions.
Bases: torch.distributions.TransformedDistribution
Extension of the Distribution class, which applies a sequence of Transforms to a base distribution. Let f be the composition of transforms applied:
X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|
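The density formula above can be checked numerically. The sketch below is illustrative and uses only Python's standard library, not PyTorch. It takes X ~ Normal(0, 1) and f = exp, so X = log Y and dX/dY = 1/Y, and compares log p(X) + log |dX/dY| with the closed-form log-normal log-density. The function names are hypothetical.

```python
import math
from statistics import NormalDist

base = NormalDist(mu=0.0, sigma=1.0)  # X ~ Normal(0, 1)


def transformed_log_prob(y: float) -> float:
    """log p(Y) for Y = exp(X), computed as log p(X) + log |dX/dY| (illustrative)."""
    x = math.log(y)              # invert the transform: X = f^{-1}(Y) = log Y
    log_det = -math.log(abs(y))  # dX/dY = 1 / Y, so log |dX/dY| = -log |Y|
    return math.log(base.pdf(x)) + log_det


def lognormal_log_prob(y: float, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Closed-form log-density of LogNormal(mu, sigma), for comparison."""
    return (-math.log(y * sigma * math.sqrt(2 * math.pi))
            - (math.log(y) - mu) ** 2 / (2 * sigma ** 2))


for y in (0.25, 1.0, 3.0):
    # The two columns agree up to floating-point rounding.
    print(y, transformed_log_prob(y), lognormal_log_prob(y))
```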
Note that the .event_shape of a TransformedDistribution is the maximum shape of its base distribution and its transforms, since transforms can introduce correlations among events.
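A short sketch of that shape rule, using the real PyTorch classes `Normal`, `ExpTransform` and `StickBreakingTransform`. It assumes PyTorch 1.8 or later, where, to our knowledge, shapes are inferred from each transform's `forward_shape` and event dimensions; older releases compute them differently. An elementwise transform keeps the base shapes. A transform whose event dimension exceeds the base's merges independent coordinates into one correlated event.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform, StickBreakingTransform

torch.manual_seed(0)
base = Normal(torch.zeros(3), torch.ones(3))  # batch_shape (3,), event_shape ()

# Elementwise transform: shapes are unchanged.
elementwise = TransformedDistribution(base, [ExpTransform()])
print(elementwise.batch_shape, elementwise.event_shape)  # torch.Size([3]) torch.Size([])

# StickBreakingTransform maps a length-3 real vector to a length-4 probability
# vector. Its event dimension (1) exceeds the base's (0), so the three
# independent coordinates become a single correlated event of size 4.
simplex = TransformedDistribution(base, [StickBreakingTransform()])
print(simplex.batch_shape, simplex.event_shape)  # torch.Size([]) torch.Size([4])

sample = simplex.sample()
print(sample.sum())  # components should sum to 1, up to float error
```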
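An example of using TransformedDistribution:
# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
import torch
from torch.distributions import TransformedDistribution, Uniform
from torch.distributions.transforms import AffineTransform, SigmoidTransform

a, b = torch.tensor(0.0), torch.tensor(1.0)  # location and scale (example values)
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)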
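The logistic construction can be checked against the closed-form logistic log-density, log p(y) = -z - log b - 2 log(1 + exp(-z)) with z = (y - a) / b. This sketch is self-contained. The values a = 1 and b = 2 are arbitrary and chosen only for illustration. `softplus(-z)` is used as a stable form of log(1 + exp(-z)).

```python
import torch
import torch.nn.functional as F
from torch.distributions import TransformedDistribution, Uniform
from torch.distributions.transforms import AffineTransform, SigmoidTransform

a, b = torch.tensor(1.0), torch.tensor(2.0)  # illustrative location and scale
logistic = TransformedDistribution(
    Uniform(0.0, 1.0), [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
)

y = torch.tensor([-3.0, 0.0, 1.0, 4.5])
z = (y - a) / b
# Closed-form logistic log-density: -z - log(b) - 2 * log(1 + exp(-z))
closed_form = -z - torch.log(b) - 2 * F.softplus(-z)

print(logistic.log_prob(y))
print(closed_form)
print(torch.allclose(logistic.log_prob(y), closed_form, atol=1e-5))
```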
For more examples, please look at the implementations of Gumbel, HalfCauchy, HalfNormal, LogNormal, Pareto, Weibull, RelaxedBernoulli and RelaxedOneHotCategorical
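The same change-of-variables recipe applies to the distributions listed above. As one worked case in the spirit of the Pareto entry: if X ~ Exponential(alpha) and Y = s * exp(X), then Y ~ Pareto(s, alpha), with log p(y) = log alpha + alpha log s - (alpha + 1) log y for y >= s. The sketch builds this as a TransformedDistribution. It compares the result with `torch.distributions.Pareto` and with the closed form. The parameter values are arbitrary.

```python
import math

import torch
from torch.distributions import Exponential, Pareto, TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform

scale, alpha = 1.5, 2.0  # illustrative Pareto parameters (s, alpha)

# X ~ Exponential(alpha), Y = scale * exp(X)  =>  Y ~ Pareto(scale, alpha)
built = TransformedDistribution(
    Exponential(torch.tensor(alpha)),
    [ExpTransform(), AffineTransform(loc=0.0, scale=scale)],
)
reference = Pareto(torch.tensor(scale), torch.tensor(alpha))

y = torch.tensor([1.6, 2.0, 5.0, 40.0])  # points strictly inside the support y >= scale
closed_form = math.log(alpha) + alpha * math.log(scale) - (alpha + 1) * torch.log(y)

print(built.log_prob(y))
print(reference.log_prob(y))
print(closed_form)
```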
arg_constraints
support
has_rsample
partition_function
scale
mean
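\(pdf = f(x; \mu, \sigma, a, b) = \frac{\phi(\xi)}{\sigma Z}\), for \(a \le x \le b\)
where \(\phi\) and \(\Phi\) are the standard normal density and cumulative distribution function, \(a\) and \(b\) are the lower and upper truncation bounds, and
\(\xi=\frac{x-\mu}{\sigma}\)
\(\alpha=\frac{a-\mu}{\sigma}\)
\(\beta=\frac{b-\mu}{\sigma}\)
\(Z=\Phi(\beta)-\Phi(\alpha)\)
Returns: \(\mu + \frac{\phi(\alpha)-\phi(\beta)}{Z}\sigma\)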
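A direct transcription of the mean formula \(\mu + \frac{\phi(\alpha)-\phi(\beta)}{Z}\sigma\) into plain Python, using `statistics.NormalDist` for \(\phi\) and \(\Phi\). This is an illustrative sketch, not kiwi's tensor implementation. The function name is hypothetical. The bounds may be infinite.

```python
import math
from statistics import NormalDist

STD_NORMAL = NormalDist()  # phi = STD_NORMAL.pdf, Phi = STD_NORMAL.cdf


def truncated_normal_mean(mu: float, sigma: float, a: float, b: float) -> float:
    """Mean of Normal(mu, sigma) truncated to [a, b] (illustrative, not kiwi's code)."""
    alpha = (a - mu) / sigma
    beta = (b - mu) / sigma
    z = STD_NORMAL.cdf(beta) - STD_NORMAL.cdf(alpha)  # partition function Z
    return mu + (STD_NORMAL.pdf(alpha) - STD_NORMAL.pdf(beta)) / z * sigma


print(truncated_normal_mean(0.5, 1.0, 0.0, 1.0))       # 0.5: truncation symmetric about mu
print(truncated_normal_mean(0.0, 1.0, 0.0, math.inf))  # sqrt(2 / pi), about 0.7979
```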
variance
Returns the variance of the distribution.
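The docstring does not give a formula. In the notation of the mean, the standard closed form for the variance of a normal truncated to \([a, b]\) is \(\sigma^2\left[1 + \frac{\alpha\phi(\alpha)-\beta\phi(\beta)}{Z} - \left(\frac{\phi(\alpha)-\phi(\beta)}{Z}\right)^2\right]\). The following is a minimal sketch, assuming (the source does not confirm this) that `variance` computes this quantity. The helper names are hypothetical.

```python
import math
from statistics import NormalDist

STD_NORMAL = NormalDist()  # standard normal: pdf is phi, cdf is Phi


def _x_phi(x: float) -> float:
    """x * phi(x), using the limit 0 at x = +/-inf (inf * 0.0 would give nan)."""
    return 0.0 if math.isinf(x) else x * STD_NORMAL.pdf(x)


def truncated_normal_variance(mu: float, sigma: float, a: float, b: float) -> float:
    """Variance of Normal(mu, sigma) truncated to [a, b] (illustrative sketch)."""
    alpha = (a - mu) / sigma
    beta = (b - mu) / sigma
    z = STD_NORMAL.cdf(beta) - STD_NORMAL.cdf(alpha)
    shift = (STD_NORMAL.pdf(alpha) - STD_NORMAL.pdf(beta)) / z  # (mean - mu) / sigma
    return sigma ** 2 * (1 + (_x_phi(alpha) - _x_phi(beta)) / z - shift ** 2)


print(truncated_normal_variance(0.0, 1.0, -math.inf, math.inf))  # 1.0: no truncation
print(truncated_normal_variance(0.0, 1.0, 0.0, math.inf))        # 1 - 2/pi, about 0.3634
```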
log_prob
Scores the sample by inverting the transform(s) and combining the log-probability of the result under the base distribution with the log absolute determinant of the Jacobian.
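For the truncated normal specifically, taking the logarithm of the pdf stated under `mean` gives \(\log f(x) = \log\phi(\xi) - \log\sigma - \log Z\) for \(a \le x \le b\), and \(-\infty\) outside the bounds. Below is a minimal standard-library sketch under that assumption. kiwi's actual `log_prob` operates on tensors and may differ in detail. The function name is hypothetical. The final lines check that the density integrates to about 1.

```python
import math
from statistics import NormalDist

STD_NORMAL = NormalDist()


def truncated_normal_log_prob(x: float, mu: float, sigma: float, a: float, b: float) -> float:
    """log f(x; mu, sigma, a, b); -inf outside [a, b] (illustrative, not kiwi's code)."""
    if not a <= x <= b:
        return -math.inf
    xi = (x - mu) / sigma
    z = STD_NORMAL.cdf((b - mu) / sigma) - STD_NORMAL.cdf((a - mu) / sigma)
    return math.log(STD_NORMAL.pdf(xi)) - math.log(sigma) - math.log(z)


# Sanity check: integrate exp(log f) over [a, b] with the midpoint rule.
mu, sigma, a, b = 0.3, 0.5, 0.0, 1.0
n = 10_000
width = (b - a) / n
total = sum(
    math.exp(truncated_normal_log_prob(a + (i + 0.5) * width, mu, sigma, a, b)) * width
    for i in range(n)
)
print(total)  # close to 1.0
```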
cdf
Computes the cumulative distribution function by inverting the transform(s) and evaluating the base distribution's CDF at the result.
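The description above comes from the generic transformed-distribution interface. For the truncated normal itself, in the notation used for `mean`, the CDF is \(F(x) = \frac{\Phi(\xi)-\Phi(\alpha)}{Z}\) on \([a, b]\), 0 below \(a\) and 1 above \(b\). The direct sketch below assumes that is the intended quantity. It is not kiwi's implementation, and the function name is hypothetical.

```python
import math
from statistics import NormalDist

STD_NORMAL = NormalDist()


def truncated_normal_cdf(x: float, mu: float, sigma: float, a: float, b: float) -> float:
    """CDF of Normal(mu, sigma) truncated to [a, b] (illustrative sketch)."""
    if x <= a:
        return 0.0
    if x >= b:
        return 1.0
    lower = STD_NORMAL.cdf((a - mu) / sigma)      # Phi(alpha)
    z = STD_NORMAL.cdf((b - mu) / sigma) - lower  # Z = Phi(beta) - Phi(alpha)
    return (STD_NORMAL.cdf((x - mu) / sigma) - lower) / z


print(truncated_normal_cdf(0.5, 0.5, 1.0, 0.0, 1.0))       # 0.5 by symmetry, up to rounding
print(truncated_normal_cdf(0.0, 0.0, 1.0, 0.0, math.inf))  # 0.0 at the lower bound
```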
icdf
Computes the inverse cumulative distribution function by evaluating the base distribution's inverse CDF and applying the transform(s) to the result.
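In the same notation, a truncated normal's inverse CDF is \(x = \mu + \sigma\,\Phi^{-1}\left(\Phi(\alpha) + p Z\right)\) for \(0 < p < 1\). This is the basis of inverse-transform sampling. A standard-library sketch, assuming that quantity is what `icdf` returns here. The function name and the seeded `random.Random` usage are illustrative only.

```python
import random
from statistics import NormalDist

STD_NORMAL = NormalDist()


def truncated_normal_icdf(p: float, mu: float, sigma: float, a: float, b: float) -> float:
    """Inverse CDF of Normal(mu, sigma) truncated to [a, b], for 0 < p < 1 (sketch)."""
    lower = STD_NORMAL.cdf((a - mu) / sigma)  # Phi(alpha)
    upper = STD_NORMAL.cdf((b - mu) / sigma)  # Phi(beta)
    # Map p onto [Phi(alpha), Phi(beta)], invert Phi, then undo the standardization.
    return mu + sigma * STD_NORMAL.inv_cdf(lower + p * (upper - lower))


# Inverse-transform sampling: uniform draws map to values inside [a, b].
rng = random.Random(0)
draws = [truncated_normal_icdf(rng.random(), 0.3, 0.5, 0.0, 1.0) for _ in range(5)]
print(draws)
```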