Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Marginal diagnostic tool (#1691)
Browse files Browse the repository at this point in the history
Summary:
This commit includes the marginal 1D diagnostic tool with JavaScript callbacks.

### Motivation

This PR completes one tool that uses Bokeh and JavaScript callbacks in order to create an interactive tool that can be viewed in Jupyter. This refactors the code in PR #1631 heavily, since pure Python callbacks were found to not function properly with internal tools.

### Changes proposed

A new `tool` folder in the `diagnostics` folder contains the proposed changes. In this folder there is a `js` folder that contains all the JavaScript callbacks needed for the Bokeh tool. The tool creates plots of marginal distributions for each random variable of the model. The output is a self-contained HTML object that can be rendered in Jupyter without any external CDN calls for JS resources.

Pull Request resolved: #1691

Test Plan:
Unit tests for the Python and JavaScript will be done at a later commit. Right now the testing was to run the tool in the Coin Flipping tutorial, and to inspect the output and ensure only static resources were used.

### Types of changes

- [ ] Docs change / refactoring / dependency upgrade
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)

### Checklist

- [x] My code follows the code style of this project.
- [ ] My change requires a change to the documentation.
- [ ] I have updated the documentation accordingly.
- [x] I have read the **[CONTRIBUTING](https://github.com/facebookresearch/beanmachine/blob/main/CONTRIBUTING.md)** document.
- [ ] I have added tests to cover my changes.
- [ ] All new and existing tests passed.
- [x] The title of my pull request is a short description of the requested changes.

### TODO

- [ ] Python unit tests
- [ ] JavaScript unit tests
- [ ] Figure out if the build should run `npm run build` for the tools, or if we should just have the minified code for the JS callbacks in the code base.

Reviewed By: feynmanliang

Differential Revision: D39714194

Pulled By: horizon-blue

fbshipit-source-id: 4d87a9fb4108c093f327a94ea33d604dcda68dc8
  • Loading branch information
ndmlny-qs authored and facebook-github-bot committed Oct 15, 2022
1 parent edf7424 commit 91bce6a
Show file tree
Hide file tree
Showing 29 changed files with 4,893 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,11 @@ botorch/qmc/sobol.c*
# Sphinx documentation
sphinx/build/

# Docusaurus
# Docusaurus and diagnostic tools
website/build/
website/i18n/
website/node_modules/
node_modules

# Tutorials
docs/overview/tutorials/*/*.mdx
Expand Down
2 changes: 2 additions & 0 deletions src/beanmachine/ppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from . import experimental
from .diagnostics import Diagnostics
from .diagnostics.common_statistics import effective_sample_size, r_hat, split_r_hat
from .diagnostics.tools import viz
from .inference import (
CompositionalInference,
empirical,
Expand Down Expand Up @@ -58,4 +59,5 @@
"random_variable",
"simulate",
"split_r_hat",
"viz",
]
22 changes: 22 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa

"""Visual diagnostic tools for Bean Machine models."""

import sys
from pathlib import Path


if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict


TOOLS_DIR = Path(__file__).parent.resolve()
JS_DIR = TOOLS_DIR.joinpath("js")
JS_DIST_DIR = JS_DIR.joinpath("dist")
75 changes: 75 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/.eslintrc.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

const OFF = 0;
const WARNING = 1;
const ERROR = 2;

module.exports = {
root: true,
env: {
browser: true,
commonjs: true,
jest: true,
node: true,
},
parser: '@typescript-eslint/parser',
parserOptions: {
allowImportExportEverywhere: true,
},
extends: ['airbnb', 'prettier', 'plugin:import/typescript'],
plugins: ['prefer-arrow'],
rules: {
// Allow more than 1 class per file.
'max-classes-per-file': ['error', {ignoreExpressions: true, max: 2}],
// Allow snake_case.
camelcase: [
OFF,
{
properties: 'never',
ignoreDestructuring: true,
ignoreImports: true,
ignoreGlobals: true,
},
],
'no-underscore-dangle': OFF,
// Arrow function rules.
'prefer-arrow/prefer-arrow-functions': [
ERROR,
{
disallowPrototype: true,
singleReturnOnly: false,
classPropertiesAllowed: false,
},
],
'prefer-arrow-callback': [ERROR, {allowNamedFunctions: true}],
'arrow-parens': [ERROR, 'always'],
'arrow-body-style': [ERROR, 'always'],
'func-style': [ERROR, 'declaration', {allowArrowFunctions: true}],
'react/function-component-definition': [
ERROR,
{
namedComponents: 'arrow-function',
unnamedComponents: 'arrow-function',
},
],
// Ignore the global require, since some required packages are BrowserOnly.
'global-require': 0,
// We reassign several parameter objects since Bokeh is just updating values in the
// them.
'no-param-reassign': 0,
// Ignore certain webpack alias because it can't be resolved
'import/no-unresolved': [
ERROR,
{ignore: ['^@theme', '^@docusaurus', '^@generated', '^@bokeh']},
],
'import/extensions': OFF,
'object-shorthand': [ERROR, 'never'],
'prefer-destructuring': [WARNING, {object: true, array: true}],
'no-nested-ternary': 0,
},
};
8 changes: 8 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/.prettierrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"arrowParens": "always",
"bracketSpacing": false,
"printWidth": 88,
"proseWrap": "never",
"singleQuote": true,
"trailingComma": "all"
}
55 changes: 55 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
{
"name": "visual-diagnostic-tools",
"version": "0.1.0",
"description": "",
"license": "MIT",
"keywords": [],
"repository": {},
"scripts": {
"build": "webpack"
},
"dependencies": {
"@bokeh/bokehjs": "^2.4.3",
"fast-kde": "^0.2.1"
},
"devDependencies": {
"@types/node": "^18.0.4",
"@typescript-eslint/eslint-plugin": "^5.30.5",
"@typescript-eslint/parser": "^5.30.5",
"eslint": "^8.19.0",
"eslint-config-airbnb": "^19.0.4",
"eslint-config-prettier": "^8.5.0",
"eslint-plugin-import": "^2.26.0",
"eslint-plugin-jsx-a11y": "^6.5.1",
"eslint-plugin-prefer-arrow": "^1.2.3",
"eslint-plugin-react": "^7.28.0",
"eslint-plugin-react-hooks": "^4.3.0",
"prettier": "^2.7.1",
"ts-loader": "^9.3.1",
"ts-node": "^10.9.1",
"typescript": "^4.7.4",
"webpack": "^5.74.0",
"webpack-cli": "^4.10.0"
},
"overrides": {
"cwise": "$cwise",
"minimist": "$minimist",
"quote-stream": "$quote-stream",
"static-eval": "$static-eval",
"static-module": "$static-module",
"typedarray-pool": "$typedarray-pool"
},
"peerDependencies": {
"@types/cwise": "^1.0.4",
"@types/minimist": "^1.2.2",
"@types/static-eval": "^0.2.31",
"@types/typedarray-pool": "^1.1.2",
"buffer": "^6.0.3",
"cwise": "^1.0.10",
"minimist": "^1.2.6",
"quote-stream": "^1.0.2",
"static-eval": "2.1.0",
"static-module": "^3.0.4",
"typedarray-pool": "^1.2.0"
}
}
190 changes: 190 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

import {Axis} from '@bokehjs/models/axes/axis';
import {cumulativeSum} from '../stats/array';
import {scaleToOne} from '../stats/dataTransformation';
import {
interval as hdiInterval,
data as hdiData,
} from '../stats/highestDensityInterval';
import {oneD} from '../stats/marginal';
import {mean as computeMean} from '../stats/pointStatistic';
import {interpolatePoints} from '../stats/utils';
import * as interfaces from './interfaces';

// Define the names of the figures used for this Bokeh application.
const figureNames = ['marginal', 'cumulative'];

/**
* Update the given Bokeh Axis object with the new label string. You must use this
* method to update axis strings using TypeScript, otherwise the ts compiler will throw
* a type check error.
*
* @param {Axis} axis - The Bokeh Axis object needing a new label.
* @param {string | null} label - The new label for the Bokeh Axis object.
*/
export const updateAxisLabel = (axis: Axis, label: string | null): void => {
// Type check requirement.
if ('axis_label' in axis) {
axis.axis_label = label;
}
};

/**
* Compute the following statistics for the given random variable data
*
* - lower bound for the highest density interval calculated from the marginalX;
* - mean of the rawData;
* - upper bound for the highest density interval calculated from the marginalY.
*
* @param {number[]} rawData - Raw random variable data from the model.
* @param {number[]} marginalX - The support of the Kernel Density Estimate of the
* random variable.
* @param {number[]} marginalY - The Kernel Density Estimate of the random variable.
* @param {number | null} [hdiProb=null] - The highest density interval probability
* value. If the default value is not overwritten, then the default HDI probability
* is 0.89. See Statistical Rethinking by McElreath for a description as to why this
* value is the default.
* @param {string[]} [text_align=['right', 'center', 'left']] - How to align the text
* shown in the figure for the point statistics.
* @param {number[]} [x_offset=[-5, 0, 5]] - Offset values for the text along the
* x-axis.
* @param {number[]} [y_offset=[0, 10, 0]] - Offset values for the text along the
* y-axis
* @returns {interfaces.LabelsData} Object containing the computed stats.
*/
export const computeStats = (
rawData: number[],
marginalX: number[],
marginalY: number[],
hdiProb: number | null = null,
text_align: string[] = ['right', 'center', 'left'],
x_offset: number[] = [-5, 0, 5],
y_offset: number[] = [0, 10, 0],
): interfaces.LabelsData => {
// Set the default value to 0.89 if no default value has been given.
const hdiProbability = hdiProb ?? 0.89;

// Compute the point statistics for the KDE, and create labels to display them in the
// figures.
const mean = computeMean(marginalX);
const hdiBounds = hdiInterval(rawData, hdiProbability);
const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound];
const y = interpolatePoints({x: marginalX, y: marginalY, points: x});
const text = [
`Lower HDI: ${hdiBounds.lowerBound.toFixed(3)}`,
`Mean: ${mean.toFixed(3)}`,
`Upper HDI: ${hdiBounds.upperBound.toFixed(3)}`,
];

return {
x: x,
y: y,
text: text,
text_align: text_align,
x_offset: x_offset,
y_offset: y_offset,
};
};

/**
* Compute data for the one-dimensional marginal diagnostic tool.
*
* @param {number[]} data - Raw random variable data from the model.
* @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when
* calculating the Kernel Density Estimate (KDE).
* @param {number} hdiProbability - The highest density interval probability to use when
* calculating the HDI.
* @returns {interfaces.Data} The marginal distribution and cumulative
* distribution calculated from the given random variable data. Point statistics are
* also calculated.
*/
export const computeData = (
data: number[],
bwFactor: number,
hdiProbability: number,
): interfaces.Data => {
const output = {} as interfaces.Data;
for (let i = 0; i < figureNames.length; i += 1) {
const figureName = figureNames[i];
output[figureName] = {} as interfaces.GlyphData;

// Compute the one-dimensional KDE and its cumulative distribution.
const distribution = oneD(data, bwFactor);
switch (figureName) {
case 'cumulative':
distribution.y = scaleToOne(cumulativeSum(distribution.y));
break;
default:
break;
}

// Compute the point statistics for the given data.
const stats = computeStats(data, distribution.x, distribution.y, hdiProbability);

output[figureName] = {
distribution: distribution,
hdi: hdiData(data, distribution.x, distribution.y, hdiProbability),
stats: {x: stats.x, y: stats.y, text: stats.text},
labels: stats,
};
}
return output;
};

/**
* Callback used to update the Bokeh application in the notebook.
*
* @param {number[]} data - Raw random variable data from the model.
* @param {string} rvName - The name of the random variable from the model.
* @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when
* calculating the kernel density estimate.
* @param {number} hdiProbability - The highest density interval probability to use when
* calculating the HDI.
* @param {interfaces.Sources} sources - Bokeh sources used to render glyphs in the
* application.
* @param {interfaces.Figures} figures - Bokeh figures shown in the application.
* @param {interfaces.Tooltips} tooltips - Bokeh tooltips shown on the glyphs.
* @returns {number} We display the value of the bandwidth used for computing the Kernel
* Density Estimate in a div, and must return that value here in order to update the
* value displayed to the user.
*/
export const update = (
data: number[],
rvName: string,
bwFactor: number,
hdiProbability: number,
sources: interfaces.Sources,
figures: interfaces.Figures,
tooltips: interfaces.Tooltips,
): number => {
const computedData = computeData(data, bwFactor, hdiProbability);
for (let i = 0; i < figureNames.length; i += 1) {
// Update all sources with new data calculated above.
const figureName = figureNames[i];
sources[figureName].distribution.data = {
x: computedData[figureName].distribution.x,
y: computedData[figureName].distribution.y,
};
sources[figureName].hdi.data = {
base: computedData[figureName].hdi.base,
lower: computedData[figureName].hdi.lower,
upper: computedData[figureName].hdi.upper,
};
sources[figureName].stats.data = computedData[figureName].stats;
sources[figureName].labels.data = computedData[figureName].labels;

// Update the axes labels.
updateAxisLabel(figures[figureName].below[0], rvName);

// Update the tooltips.
tooltips[figureName].stats.tooltips = [['', '@text']];
tooltips[figureName].distribution.tooltips = [[rvName, '@x']];
}
return computedData.marginal.distribution.bandwidth;
};
12 changes: 12 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

import * as marginal1d from './callbacks';

// The CustomJS methods used by Bokeh require us to make the JavaScript available in the
// browser, which is done by defining it below.
(window as any).marginal1d = marginal1d;
Loading

0 comments on commit 91bce6a

Please sign in to comment.