Skip to content

Commit

Permalink
Adding AQSOL Dataset (#240)
Browse files Browse the repository at this point in the history
* AQSOL

* aqsol working

* remove redundant file

* remove redundant file

* Update graphs.jl

* Update graphs.jl: Fixing tests

* Update AQSOL.jl: docs

* Update AQSOL.jl: better docstring

* Update MLDatasets.jl

* Update AQSOL.jl: fix edge_index

* Update graphs.jl: update tests

* Update graphs.jl: check only random graphs

* Update graphs.md: add to docs
  • Loading branch information
rbSparky authored Aug 22, 2024
1 parent 6b7d256 commit aa55d80
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/src/datasets/graphs.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ MLDatasets.HeteroGraph
```

```@docs
AQSOL
ChickenPox
CiteSeer
Cora
Expand Down
4 changes: 3 additions & 1 deletion src/MLDatasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ include("datasets/graphs/citeseer.jl")
export CiteSeer
include("datasets/graphs/karateclub.jl")
export KarateClub
include("datasets/graphs/AQSOL.jl")
export AQSOL
include("datasets/graphs/movielens.jl")
export MovieLens
include("datasets/graphs/ogbdataset.jl")
Expand Down Expand Up @@ -151,6 +153,7 @@ function __init__()
# TODO automatically find and execute __init__xxx functions

# graph
__init__aqsol()
__init__chickenpox()
__init__citeseer()
__init__cora()
Expand All @@ -166,7 +169,6 @@ function __init__()
__init__temporalbrains()
__init__windmillenergy()


# misc
__init__iris()
__init__mutagenesis()
Expand Down
100 changes: 100 additions & 0 deletions src/datasets/graphs/AQSOL.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Register the AQSOL archive as a `DataDep` under the name "AQSOL".
# Called from the package-level `__init__` together with the other dataset
# registrations.
function __init__aqsol()
    name = "AQSOL"
    url = "https://www.dropbox.com/s/lzu9lmukwov12kt/aqsol_graph_raw.zip?dl=1"
    # Human-readable description attached to the registration.
    description = """
                  Dataset: The AQSOL dataset.
                  Website: http://arxiv.org/abs/2003.00982
                  """
    # `unpack` is passed as the post-fetch step for the downloaded archive.
    dep = DataDep(name, description, url; post_fetch_method = unpack)
    return register(dep)
end

# Container for one split of the AQSOL dataset; built by `AQSOL(; split, dir)`.
struct AQSOL <: AbstractDataset
    # Name of the loaded split (`:train`, `:val` or `:test`).
    split::Symbol
    # Dataset-level information; the constructor stores "n_observations" here.
    metadata::Dict{String,Any}
    # One `Graph` per molecule of the split.
    graphs::Vector{Graph}
end

"""
    AQSOL(; split=:train, dir=nothing)

The AQSOL (Aqueous Solubility) dataset from the paper
[Graph Neural Network for Predicting Aqueous Solubility of Organic Molecules](http://arxiv.org/abs/2003.00982).

The dataset contains 9,982 graphs representing small organic molecules
(7,985 for training, 998 for validation and 999 for testing).
Each graph represents a molecule, where nodes correspond to atoms and edges to bonds.
The node features represent the atomic number, and the edge features represent the bond type.
The target is the aqueous solubility of the molecule, measured in mol/L.

Note that the target value is read from the raw data but is currently not stored in the
returned graphs.

# Arguments

- `split`: Which split of the dataset to load. Can be one of `:train`, `:val`, or `:test`. Defaults to `:train`.
- `dir`: Directory in which the dataset is in.

# Examples

```julia-repl
julia> using MLDatasets

julia> data = AQSOL()
dataset AQSOL:
  split     =>    :train
  metadata  =>    Dict{String, Any} with 1 entry
  graphs    =>    7985-element Vector{MLDatasets.Graph}

julia> length(data)
7985

julia> g = data[1]
Graph:
  num_nodes   =>    23
  num_edges   =>    42
  edge_index  =>    ("42-element Vector{Int64}", "42-element Vector{Int64}")
  node_data   =>    (features = "23-element Vector{Int64}",)
  edge_data   =>    (features = "42-element Vector{Int64}",)

julia> g.num_nodes
23

julia> g.node_data.features
23-element Vector{Int64}:
 0
 1
 1
 1
 1
 1
 ⋮

julia> g.edge_index
([2, 3, 3, 4, 4, 5, 5, 6, 6, 7 … 18, 19, 19, 20, 20, 21, 20, 22, 20, 23], [3, 2, 4, 3, 5, 4, 6, 5, 7, 6 … 19, 18, 20, 19, 21, 20, 22, 20, 23, 20])
```
"""
function AQSOL(; split = :train, dir = nothing)
    # The membership test is required here: asserting on `split` alone would evaluate a
    # Symbol in boolean context instead of validating the argument.
    @assert split in (:train, :val, :test) "split must be :train, :val or :test, got $(repr(split))"
    DEPNAME = "AQSOL"
    # NOTE(review): the folder is spelled "asqol" while the registered archive is
    # "aqsol_graph_raw.zip" — presumably this matches the name of the unpacked folder;
    # confirm before "correcting" the spelling.
    path = datafile(DEPNAME, "asqol_graph_raw/$(split).pickle", dir)
    records = Pickle.npyload(path)
    # Each record is splatted into `create_aqsol_graph(x, edge_attr, edge_index, y)`.
    graphs = Graph[create_aqsol_graph(record...) for record in records]
    metadata = Dict{String, Any}("n_observations" => length(graphs))
    return AQSOL(split, metadata, graphs)
end

# Convert one raw AQSOL record into a `Graph`.
#
# `x` becomes the node features and `edge_attr` the edge features; the two rows of
# `edge_index` become the source and target vectors of the edge list. All three are
# converted to `Int`.
# NOTE(review): `y` is accepted but never used, so it is not part of the returned graph.
function create_aqsol_graph(x, edge_attr, edge_index, y)
    node_feats = Int.(x)
    edge_feats = Int.(edge_attr)
    # Every index is shifted up by one — presumably from 0-based indices in the raw data
    # to Julia's 1-based convention; confirm against the data source.
    shifted = Int.(edge_index .+ 1)

    # With no columns there are no edges, so empty endpoint vectors are used instead of
    # slicing the rows.
    has_edges = size(shifted, 2) != 0
    src = has_edges ? shifted[1, :] : Int[]
    dst = has_edges ? shifted[2, :] : Int[]

    return Graph(; num_nodes = length(node_feats),
                 edge_index = (src, dst),
                 node_data = (features = node_feats,),
                 edge_data = (features = edge_feats,))
end

# Collection interface: an `AQSOL` dataset is as long as its vector of graphs.
function Base.length(d::AQSOL)
    return length(d.graphs)
end

# `d[:]` returns the stored vector of graphs itself.
function Base.getindex(d::AQSOL, ::Colon)
    return d.graphs
end

# Any other index is forwarded to the underlying vector of graphs.
function Base.getindex(d::AQSOL, i)
    return d.graphs[i]
end
2 changes: 1 addition & 1 deletion src/datasets/graphs/movielens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,4 +549,4 @@ Base.length(data::MovieLens) = length(data.graphs)
# `data[:]` returns the single graph when exactly one is stored, and the whole
# collection of graphs otherwise.
function Base.getindex(data::MovieLens, ::Colon)
    if length(data.graphs) == 1
        return data.graphs[1]
    end
    return data.graphs
end
# Indexing with anything other than `:` delegates to `getobs` on the stored graphs.
# The definition appears twice in this view; both lines are identical.
Base.getindex(data::MovieLens, i) = getobs(data.graphs, i)
Base.getindex(data::MovieLens, i) = getobs(data.graphs, i)
20 changes: 20 additions & 0 deletions test/datasets/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,23 @@ end
@test size(g.edge_data.features) == (2, g.num_edges)
@test size(g.edge_data.targets) == (g.num_edges,)
end

@testset "AQSOL" begin
    # Expected number of graphs in each split.
    split_counts = Dict(:train => 7985, :val => 998, :test => 999)
    for split in [:train, :val, :test]
        data = AQSOL(split = split)
        @test data isa AbstractDataset
        @test data.split == split
        @test length(data) == data.metadata["n_observations"]
        @test length(data.graphs) == split_counts[split]

        # Only a few graphs are inspected per split. Fixed positions (first, middle,
        # last) are used rather than a random index so that a failure is reproducible.
        n = length(data)
        for i in unique([1, cld(n, 2), n])
            g = data[i]
            @test g isa MLDatasets.Graph
            s, t = g.edge_index
            # Edge endpoints must be valid 1-based node indices.
            @test all(1 .<= s .<= g.num_nodes)
            @test all(1 .<= t .<= g.num_nodes)
            @test length(s) == g.num_edges
            @test length(t) == g.num_edges
        end
    end
end

0 comments on commit aa55d80

Please sign in to comment.